Функция встраивания в поисковый декодер луча - PullRequest
0 голосов
/ 19 февраля 2020

Я новичок в глубоких нейронных сетях и, следовательно, в луч декодера поиска. Я пытаюсь реализовать указатель сети с декодером поиска луча на уже существующий код. Я не могу понять, как работает функция встраивания в декодер лучевого поиска. Моя функция вложения передает значения вложения следующих входных данных, которые были предсказаны. Но когда я пытаюсь проверить доступ к функции встраивания, просто добавив в нее оператор print, я вижу, что функция встраивания используется только 2 раза. Я прошел через класс tf.contrib.seq2seq.BeamSearchDecoder, но не могу понять, как он получает все значения встраивания, просто дважды обращаясь к функции встраивания. Я делюсь фрагментом кода. Любая помощь будет высоко ценится. Спасибо

def embedding_lookup(ids):
    # Note the output value of the decoder only ranges 0 to max_input_sequence_len
    # while embedding_table contains two more tokens' values 
    # To get around this, shift ids
    # Shape: [batch_size,beam_width] 
    ids = ids+2
    print("check access")
    # Shape: [batch_size,beam_width,vocab_size] 128,2,8
    one_hot_ids_base = tf.cast(tf.one_hot(ids,self.vocab_size), dtype=tf.float32)
    # Shape: [batch_size,beam_width,vocab_size,1] 128,2,8,1
    one_hot_ids_base = tf.expand_dims(one_hot_ids_base,-1)
    # Shape: [batch_size,beam_width,features_size]
    next_inputs = tf.reduce_sum(one_hot_ids_base*tile_embedding_table, axis=2)  
    user = user + 1
    return next_inputs
  # Do a little trick so that we can use 'BeamSearchDecoder'
  shifted_START_ID = START_ID - 2
  print("Shifted start id", shifted_START_ID)
  shifted_END_ID = END_ID - 2
  print("Shifted end", shifted_END_ID)
  # Beam Search Decoder
  decoder = tf.contrib.seq2seq.BeamSearchDecoder(dec_cell, embedding_lookup, 
                                      tf.tile([shifted_START_ID],[self.batch_size]), shifted_END_ID, 
                                      dec_cell.zero_state(self.batch_size*beam_width,tf.float32), beam_width)
  # Decode
  outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
  # predicted_ids
  # Shape: [batch_size, max_output_sequence_len,  beam_width]
  predicted_ids = outputs.predicted_ids
...