Почему вход тензорного потока SimpleRNNCell является 3D? - PullRequest
1 голос
/ 29 марта 2019

Я тестирую тензор потока tf.keras.layers.SimpleRNNCell.Я нахожу это таким странным.Я думаю, что ячейка RNN - это единица для получения предыдущего состояния a^{<t-1>} и текущего ввода данных x^{<t>}.Он выведет новое состояние a^{<t>} и текущий прогноз \hat{y}^{<t>}.

Таким образом, вход SimpleRNNCell должен быть 2d, если установлен batch_size.Я думаю, что ввод должен быть [batch_size,feature_size].Однако это вызовет ошибку, если вход 2D.И предыдущие состояния также нуждаются в 3D.

Правильный код выглядит следующим образом:

batch_data = tf.ones((batch_size, time_steps, label_num))    
simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)
initial_state = tf.zeros((batch_size, time_steps, units))
output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)

Однако, я думаю, следующий код был правильным.Но я не прав

batch_data = tf.ones((batch_size, label_num))    
simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)
initial_state = tf.zeros((batch_size, units))
output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)

Так что мой вопрос, почему ввод SimpleRNNCell является 3D?

1 Ответ

0 голосов
/ 29 марта 2019

Вход RNN (или LSTM) должен иметь форму [batch_size, timesteps, nbr_features]

...