обновите код rnn.static_bidirectional_rnn, чтобы он соответствовал API-интерфейсу tenorflow 2.0 - PullRequest
1 голос
/ 07 мая 2019
import tensorflow as tf
from tf.contrib import rnn
lstm_f = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_b = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
blstm_out, state_f, state_b = rnn.static_bidirectional_rnn(lstm_f, lstm_b, x, dtype=tf.float32)

Приведенный выше код работает с tenorflow 1.x, однако мне сложно найти способ переписать этот код с помощью API tenorflow 2.0.

Я знаю, что мне следует начать с tf.keras.layers.LSTMCell (), но я не знаю, какая функция API подходит для 2 экземпляров LSTMCell в качестве входных данных.

1 Ответ

1 голос
/ 07 мая 2019

Эквивалент Keras для вашего фрагмента будет

lstm = keras.layers.LSTM(n_hidden, unit_forget_bias=True, unroll=True)
keras.layers.Bidirectional(lstm)

Обратите внимание, что, хотя Keras имеет реализацию LSTMCell, вы можете вместо этого использовать LSTM, которая представляет собой не просто ячейку, а полностью развернутый RNN, работающий сразу за всей последовательностью. По умолчанию RNN динамически разворачивается через цикл while, мы заставляем его быть статическим (в терминах TF 1.X), передавая unroll=True. Наконец, оболочка keras.layers.Bidirectional делает двунаправленный RNN.

...