Получить контекстные векторы из AttentionWrapper - PullRequest
0 голосов
/ 11 февраля 2019

Мне нужно извлечь контекстные векторы из механизма внимания, примененного к модели Seq2Seq.

Моим первым предположением было то, что я могу найти их в выходных данных динамического_кода

decoder_cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(self.num_units, self.dropout) for _ in range(self.num_layers)])

attention_mechanism = tf.contrib.seq2seq.LuongAttention(self.num_units_attention, encoder_output, memory_sequence_length=x_lengths)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism, attention_layer_size=self.num_units_attention, alignment_history=True)

decoder_initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state)

decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state, output_layer=self.projection_layer)

outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,maximum_iterations=maximum_iterations)

context_vectors = outputs.rnn_outputs

Но я понимаю, что это вычисленные векторы внимания, а не контекстные векторы.

Где векторы контекста хранятся в AttentionWrapper?

...