Как исправить AttentionWrapperState в TensorFlow - PullRequest
0 голосов
/ 05 декабря 2018

Итак, я использую пример модели seq2seq, и она была сделана в TensorFlow 1.1.0, поэтому появился класс с именем DynamicAttentionWrapperState, который теперь был заменен (я предполагаю) просто AttentionWrapperState.Я не могу просто понизить версию до н.э. Я использую Google Colab, поэтому мне нужно обновить код для последней версии (1.12.0).Когда я просто изменяю код для удаления части строки Dynamic, я получаю эту ошибку:

TypeError: __new__() missing 4 required positional arguments: 'time', 'alignments', 'alignment_history', and 'attention_state'

Я рассуждал, потому что новый класс требует больше параметров, чем старый.К сожалению, я очень новичок в этом, поэтому я не знаю, что добавить в каждый из аргументов.Ниже мой код.Спасибо, что помогли мне разобраться!

def decoding_layer(dec_embed_input, embeddings, enc_output, enc_state, vocab_size, text_length, summary_length, 
                   max_summary_length, rnn_size, vocab_to_int, keep_prob, batch_size, num_layers):
    '''Create the decoding cell and attention for the training and inference decoding layers'''

    for layer in range(num_layers):
        with tf.variable_scope('decoder_{}'.format(layer)):
            lstm = tf.contrib.rnn.LSTMCell(rnn_size,
                                           initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
            dec_cell = tf.contrib.rnn.DropoutWrapper(lstm, 
                                                     input_keep_prob = keep_prob)

    output_layer = Dense(vocab_size,
                         kernel_initializer = tf.truncated_normal_initializer(mean = 0.0, stddev=0.1))

    attn_mech = tf.contrib.seq2seq.BahdanauAttention(rnn_size,
                                                  enc_output,
                                                  text_length,
                                                  normalize=False,
                                                  name='BahdanauAttention')

    dec_cell = tf.contrib.seq2seq.AttentionWrapper(dec_cell,
                                                          attn_mech,
                                                          rnn_size)

    initial_state = tf.contrib.seq2seq.AttentionWrapperState(enc_state[0],_zero_state_tensors(rnn_size, batch_size, tf.float32)) 
    with tf.variable_scope("decode"):
        training_logits = training_decoding_layer(dec_embed_input, 
                                                  summary_length, 
                                                  dec_cell, 
                                                  initial_state,
                                                  output_layer,
                                                  vocab_size, 
                                                  max_summary_length)
    with tf.variable_scope("decode", reuse=True):
        inference_logits = inference_decoding_layer(embeddings,  
                                                    vocab_to_int['<GO>'], 
                                                    vocab_to_int['<EOS>'],
                                                    dec_cell, 
                                                    initial_state, 
                                                    output_layer,
                                                    max_summary_length,
                                                    batch_size)

    return training_logits, inference_logits
...