Я использую код MurtiShikhar в качестве ссылки.
Я создал ячейку (которая является входом для внимания_wrapper.py) как:
with tf.variable_scope("match_lstm_attender"):
attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
cell = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
lstm_attender = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)
# we don't mask the passage because masking the memories will be handled by the pointerNet
reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)
output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")
output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")
output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)
Но это дает
«Ошибка типа: ячейка должна быть RNNCell, тип пилы: BasicLSTMCell».
Как предотвратить эту ошибку?