Как сделать так, чтобы ячейка Basi c LSTM передавалась в качестве ввода в Внимание_wrapper.py? - PullRequest
0 голосов
/ 26 апреля 2020

Я использую код MurtiShikhar в качестве ссылки.

Я создал ячейку (которая является входом для внимания_wrapper.py) как:

with tf.variable_scope("match_lstm_attender"):
            attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
            cell = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
            lstm_attender  = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)

            # we don't mask the passage because masking the memories will be handled by the pointerNet
            reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)

            output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")    
            output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")

            output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)

Но это дает

«Ошибка типа: ячейка должна быть RNNCell, тип пилы: BasicLSTMCell».

Как предотвратить эту ошибку?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...