Инициализация скрытого состояния LSTM в TensorFlow - PullRequest
0 голосов
/ 04 декабря 2018

Может кто-нибудь сказать мне, в рамках TensorFlow, как инициализировать скрытые состояния сети LSTM с пользовательскими значениями?Я пытаюсь включить дополнительную информацию в LSTM, указав конкретное скрытое состояние первой ячейки LSTM.

1 Ответ

0 голосов
/ 06 декабря 2018

Вы можете передать начальное скрытое состояние LSTM с помощью параметра initial_state функции, отвечающей за развертывание графика.

Я предполагаю, что для создания рекуррентной нейронной сети (RNN) вы будете использовать некоторые из следующих функций в tenorflow: tf.nn.dynamic_rnn , bidirectional_dynamic_rnn , tf.nn.static_rnn или tf.nn.static_bidirectional_rnn .Все они имеют параметр initial_state.В случае двунаправленного RNN вам необходимо передать начальные состояния для прямого (initial_state_fw) и обратного (initial_state_bw) проходов.

Пример, который определяет модель с tf.nn.dynamic_rnn:

import tensorflow as tf

batch_size = 32
max_sequence_length = 100
num_features = 128
num_units = 64 

input_sequence = tf.placeholder(tf.float32, shape=[batch_size, max_sequence_length, num_features])
input_sequence_lengths = tf.placeholder(tf.int32, shape=[batch_size])

cell = tf.nn.rnn_cell.LSTMCell(num_units=num_units, state_is_tuple=True)

# Initial states
cell_state = tf.zeros([batch_size, num_units], tf.float32)
hidden_state = tf.placeholder(tf.float32, [batch_size, num_units])
my_initial_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)

outputs, states = tf.nn.dynamic_rnn(
                    cell=cell,
                    inputs=input_sequence,
                    initial_state=my_initial_state,
                    sequence_length=input_sequence_lengths)

Поскольку мы используем state_is_tuple=True, нам нужно передать начальное состояние, которое является кортежем cell_state и hidden_state.В документации LSTMCell этот кортеж соответствует c_state и m_state, что предыдущее обсуждение указывает на то, что оно представляет состояние ячейки и скрытое состояние соответственно.

Следовательно, поскольку мы хотим инициализировать только первое скрытое состояние, cell_state инициализируется нулями.

...