ValueError при двойном использовании tf.nn.dynamic_rnn - PullRequest
0 голосов
/ 26 апреля 2020

У меня есть модель последовательности на основе LSTM:

    lstmCell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(lstm_units),tf.contrib.rnn.BasicLSTMCell(lstm_units)])
    '''hidden_output is the hidden state vectors of lstm cells [batch_size * time_steps * lstm_units]'''
    hidden_output, _ = tf.nn.dynamic_rnn(lstmCell,embedding_vec,dtype=tf.float32)

В настоящее время я хочу изменить модель на две последовательности LSTM и объединить их на скрытом выходе. Поэтому я добавляю код:

lstmCell_2 = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(lstm_units),tf.contrib.rnn.BasicLSTMCell(lstm_units)])
'''hidden_output is the hidden state vectors of lstm cells [batch_size * time_steps * lstm_units]'''
hidden_output_2, _ = tf.nn.dynamic_rnn(lstmCell_2,path_2_input,dtype=tf.float32)

Я собираюсь объединить последний шаг hidden_output и hidden_output_2. Тем не менее, я всегда получаю ValueError во второй используемый tf.nn.dynamic_rnn:

    ValueError: in converted code:
        relative to /Users/arielxiao/Library/Python/3.7/lib/python/site-packages/tensorflow/python:

        ops/rnn_cell_impl.py:1719 call *
            cur_inp, new_state = cell(cur_inp, cur_state)
        ops/rnn_cell_impl.py:385 __call__
            self, inputs, state, scope=scope, *args, **kwargs)
        layers/base.py:537 __call__
            outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
        keras/engine/base_layer.py:591 __call__
            self._maybe_build(inputs)
        keras/engine/base_layer.py:1881 _maybe_build
            self.build(input_shapes)
        keras/utils/tf_utils.py:295 wrapper
            output_shape = fn(instance, input_shape)
        ops/rnn_cell_impl.py:734 build
            shape=[input_depth + h_depth, 4 * self._num_units])
        keras/engine/base_layer.py:1484 add_variable
            return self.add_weight(*args, **kwargs)
        layers/base.py:450 add_weight
            **kwargs)
        keras/engine/base_layer.py:384 add_weight
            aggregation=aggregation)
        training/tracking/base.py:663 _add_variable_with_custom_getter
            **kwargs_for_getter)
...

Означает ли это, что этот API нельзя использовать более одного раза? Есть ли способ удовлетворить мое требование?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...