Форма предикатов в выходных данных `tf.contrib.seq2seq.BeamSearchDecoder` - PullRequest
0 голосов
/ 24 мая 2018

Какова форма содержимого в outputs из tf.contrib.seq2seq.BeamSearchDecoder.Я знаю, что это экземпляр class BeamSearchDecoderOutput(scores, predicted_ids, parent_ids), но какова форма scores, predicted_ids и parent_ids?

1 Ответ

0 голосов
/ 25 мая 2018

Я сам написал следующий игрушечный код, чтобы немного его изучить.

tgt_vocab_size = 20
embedding_decoder = tf.one_hot(list(range(0, tgt_vocab_size)), tgt_vocab_size)
batch_size = 2
start_tokens = tf.fill([batch_size], 0)
end_token = 1
beam_width = 3
num_units=18

decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
encoder_outputs = decoder_cell.zero_state(batch_size, dtype=tf.float32)
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)

my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell=decoder_cell,
                                                  embedding=embedding_decoder,
                                                  start_tokens=start_tokens,
                                                  end_token=end_token,
                                                  initial_state=tiled_encoder_outputs,
                                                  beam_width=beam_width)

 # dynamic decoding
outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder,
                                                                   maximum_iterations=4,
                                                                   output_time_major=True)
final_predicted_ids = outputs.predicted_ids
scores = outputs.beam_search_decoder_output.scores
predicted_ids = outputs.beam_search_decoder_output.predicted_ids
parent_ids = outputs.beam_search_decoder_output.parent_ids

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    final_predicted_ids_vals = sess.run(final_predicted_ids)
    print("final_predicted_ids shape:")
    print(final_predicted_ids_vals.shape)
    print("final_predicted_ids_vals: \n%s" %final_predicted_ids_vals)
    print("scores shape:")
    print(sess.run(scores).shape)
    print("scores values: \n %s" % sess.run(scores))
    print("predicted_ids shape: ")
    print(sess.run(predicted_ids).shape)
    print("predicted_ids values: \n %s" % sess.run(predicted_ids))
    print("parent_ids shape:")
    print(sess.run(parent_ids).shape)
    print("parent_ids values: \n %s" % sess.run(parent_ids))

Печать выглядит следующим образом:

final_predicted_ids shape:
(4, 2, 3)
final_predicted_ids_vals: 
[[[ 1  8  8]
  [ 1  8  8]]

 [[ 1 13 13]
  [ 1 13 13]]

 [[ 1 13 13]
  [ 1 13 13]]

 [[ 1 13  2]
  [ 1 13  2]]]
scores shape:
(4, 2, 3)
scores values: 
 [[[ -2.8376358  -2.843168   -2.8478816]
  [ -2.8376358  -2.843168   -2.8478816]]

 [[ -2.8478816  -5.655898   -5.6810265]
  [ -2.8478816  -5.655898   -5.6810265]]

 [[ -2.8478816  -8.478384   -8.495466 ]
  [ -2.8478816  -8.478384   -8.495466 ]]

 [[ -2.8478816 -11.292251  -11.307263 ]
  [ -2.8478816 -11.292251  -11.307263 ]]]
predicted_ids shape: 
(4, 2, 3)
predicted_ids values: 
 [[[ 8 13  1]
  [ 8 13  1]]

 [[ 1 13 13]
  [ 1 13 13]]

 [[ 1 13 12]
  [ 1 13 12]]

 [[ 1 13  2]
  [ 1 13  2]]]
parent_ids shape:
(4, 2, 3)
parent_ids values: 
 [[[0 0 0]
  [0 0 0]]

 [[2 0 1]
  [2 0 1]]

 [[0 1 1]
  [0 1 1]]

 [[0 1 1]
  [0 1 1]]]

outputs из tf.contrib.seq2seq.dynamic_decode(BeamSearchDecoder) на самом делеэкземпляр class FinalBeamSearchDecoderOutput, который состоит из:

predicted_ids: конечные выходные данные, возвращаемые при поиске луча после завершения всего декодирования.Тензор формы [batch_size, num_steps, beam_width] (или [num_steps, batch_size, beam_width], если output_time_major равно True).Лучи упорядочены от лучшего к худшему.

beam_search_decoder_output: экземпляр BeamSearchDecoderOutput, который описывает состояние поиска луча.

Поэтому необходимо убедиться, что окончательные прогнозы / переводы имеют форму[beam_width, batch_size, num_steps] на transpose([2, 0, 1]) или tf.transpose(final_predicted_ids), если output_time_major=True.

...