Итак, я использую пример модели seq2seq, и она была сделана в TensorFlow 1.1.0, поэтому появился класс с именем DynamicAttentionWrapperState
, который теперь был заменен (я предполагаю) просто AttentionWrapperState
.Я не могу просто понизить версию до н.э. Я использую Google Colab, поэтому мне нужно обновить код для последней версии (1.12.0).Когда я просто изменяю код для удаления части строки Dynamic
, я получаю эту ошибку:
TypeError: __new__() missing 4 required positional arguments: 'time', 'alignments', 'alignment_history', and 'attention_state'
Я рассуждал, потому что новый класс требует больше параметров, чем старый.К сожалению, я очень новичок в этом, поэтому я не знаю, что добавить в каждый из аргументов.Ниже мой код.Спасибо, что помогли мне разобраться!
def decoding_layer(dec_embed_input, embeddings, enc_output, enc_state, vocab_size, text_length, summary_length,
max_summary_length, rnn_size, vocab_to_int, keep_prob, batch_size, num_layers):
'''Create the decoding cell and attention for the training and inference decoding layers'''
for layer in range(num_layers):
with tf.variable_scope('decoder_{}'.format(layer)):
lstm = tf.contrib.rnn.LSTMCell(rnn_size,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
dec_cell = tf.contrib.rnn.DropoutWrapper(lstm,
input_keep_prob = keep_prob)
output_layer = Dense(vocab_size,
kernel_initializer = tf.truncated_normal_initializer(mean = 0.0, stddev=0.1))
attn_mech = tf.contrib.seq2seq.BahdanauAttention(rnn_size,
enc_output,
text_length,
normalize=False,
name='BahdanauAttention')
dec_cell = tf.contrib.seq2seq.AttentionWrapper(dec_cell,
attn_mech,
rnn_size)
initial_state = tf.contrib.seq2seq.AttentionWrapperState(enc_state[0],_zero_state_tensors(rnn_size, batch_size, tf.float32))
with tf.variable_scope("decode"):
training_logits = training_decoding_layer(dec_embed_input,
summary_length,
dec_cell,
initial_state,
output_layer,
vocab_size,
max_summary_length)
with tf.variable_scope("decode", reuse=True):
inference_logits = inference_decoding_layer(embeddings,
vocab_to_int['<GO>'],
vocab_to_int['<EOS>'],
dec_cell,
initial_state,
output_layer,
max_summary_length,
batch_size)
return training_logits, inference_logits