Я пытаюсь построить сеть LSTM тензорного потока без использования Keras API. Модель очень проста:
- ввод последовательности из 4-х индексов слов
- встраивание входных векторов 100 dim word
- проход через слой LSTM
- плотный слой с выводом последовательности из 4 слов
Функция потери - это потеря последовательности.
У меня есть следующий код:
# input
input_placeholder = tf.placeholder(tf.int32, shape=[config.batch_size, config.num_steps], name='Input')
labels_placeholder = tf.placeholder(tf.int32, shape=[config.batch_size, config.num_steps], name='Target')
# embedding
embedding = tf.get_variable('Embedding', initializer=embedding_matrix, trainable=False)
inputs = tf.nn.embedding_lookup(embedding, input_placeholder)
inputs = [tf.squeeze(x, axis=1) for x in tf.split(inputs, config.num_steps, axis=1)]
# LSTM
initial_state = tf.zeros([config.batch_size, config.hidden_size])
lstm_cell = tf.nn.rnn_cell.LSTMCell(config.hidden_size)
output, _ = tf.keras.layers.RNN(lstm_cell, inputs, dtype=tf.float32, unroll=True)
# loss op
all_ones = tf.ones([config.batch_size, config.num_steps])
cross_entropy = tfa.seq2seq.sequence_loss(output, labels_placeholder, all_ones, vocab_size)
tf.add_to_collection('total_loss', cross_entropy)
loss = tf.add_n(tf.get_collection('total_loss'))
# projection (dense)
proj_U = tf.get_variable('Matrix', [config.hidden_size, vocab_size])
proj_b = tf.get_variable('Bias', [vocab_size])
outputs = [tf.matmul(o, proj_U) + proj_b for o in output]
У меня проблема в том, что в части LSTM сейчас:
# tensorflow 1.x
output, _ = tf.contrib.rnn.static_rnn(
lstm_cell, inputs, dtype = tf.float32,
sequence_length = [config.num_steps]*config.batch_size)
У меня возникли проблемы с преобразованием этого значения в tenorslow 2. В приведенном выше коде я получаю следующую ошибку:
--- -------------------------------------------------- ---------------------- TypeError Traceback (последний последний вызов) в ----> 1 выходах, _ = tf.keras.layers.RNN (lstm_cell , input, dtype = tf.float32, unroll = True)
TypeError: невозможно распаковать не повторяемый объект RNN