Я пытаюсь использовать одну и ту же архитектуру LSTM для разных входов и, следовательно, проходить одни и те же ячейки, разворачивая двунаправленный LSTM, одновременно разворачивая разные входы. Я не уверен, что он создает две совершенно разные сети LSTM. Похоже, в моем графике есть два разных узла. Мой код и график выглядят примерно так:
def get_multirnn_cell(self):
cells = []
for _ in range(config.n_layers):
cell = tf.nn.rnn_cell.LSTMCell(config.n_hidden, initializer=tf.glorot_uniform_initializer())
dropout_cell = tf.nn.rnn_cell.DropoutWrapper(cell=cell,
input_keep_prob=config.keep_prob,
output_keep_prob=config.keep_prob)
cells.append(dropout_cell)
return cells
def add_lstm_op(self):
with tf.variable_scope('lstm'):
cells_fw = self.get_multirnn_cell()
cells_bw = self.get_multirnn_cell()
cell_fw = tf.nn.rnn_cell.MultiRNNCell(cells_fw)
cell_bw = tf.nn.rnn_cell.MultiRNNCell(cells_bw)
(_, _), (state_one_fw, state_one_bw) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw,
inputs=self.question_one,
sequence_length=self.seql_one,
dtype=tf.float32)
self.state_one = tf.concat([state_one_fw[-1].h, state_one_bw[-1].h], name='state_one', axis=-1)
# self.state_one = tf.concat([state_one_fw, state_one_bw], axis=-1)
# [batch_size, 2*hidden_size]
(_, _), (state_two_fw, state_two_bw) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw,
inputs=self.question_two,
sequence_length=self.seql_two,
dtype=tf.float32)
self.state_two = tf.concat([state_two_fw[-1].h, state_two_bw[-1].h], name='state_two', axis=-1)