Понятия не имею, как возникает эта ошибка. Я пытаюсь изменить формат ввода на RNN и распечатать тензоры в исходной версии (которая работает) и измененной версии (которая вылетает).
ФУНКЦИОНАЛЬНАЯ:
LABEL= Tensor("concat_1:0", shape=(?, 2), dtype=float32, device=/device:CPU:0) (?, 2)
inputs=Tensor("concat:0", shape=(?, 8), dtype=float32, device=/device:CPU:0)
x=[<tf.Tensor 'split:0' shape=(?, 1) dtype=float32>,
<tf.Tensor 'split:1' shape=(?, 1) dtype=float32>,
<tf.Tensor 'split:2' shape=(?, 1) dtype=float32>,
<tf.Tensor 'split:3' shape=(?, 1) dtype=float32>,
<tf.Tensor 'split:4' shape=(?, 1) dtype=float32>,
<tf.Tensor 'split:5' shape=(?, 1) dtype=float32>,
<tf.Tensor 'split:6' shape=(?, 1) dtype=float32>,
<tf.Tensor 'split:7' shape=(?, 1) dtype=float32>]
last outputs=Tensor("rnn/rnn/basic_lstm_cell/mul_23:0", shape=(?, 3), dtype=float32)
PREDICTION Tensor("add:0", shape=(?, 2), dtype=float32)
LOSS Tensor("mean_squared_error/value:0", shape=(), dtype=float32)
СЛОМЛЕННОЕ:
X= 5 Tensor("Const:0", shape=(49, 10), dtype=float32, device=/device:CPU:0)
labels= Tensor("Const_5:0", shape=(49, 10), dtype=float32)
OUTPUTS Tensor("rnn/rnn/basic_lstm_cell/mul_14:0", shape=(49, 5), dtype=float32)
PREDICTIONS Tensor("add:0", shape=(49, 10), dtype=float32)
LABELS Tensor("Const_5:0", shape=(49, 10), dtype=float32)
LOSS Tensor("mean_squared_error/value:0", shape=(), dtype=float32)
Вот код для модели, который одинаков для каждой из них:
lstm_cell = rnn.BasicLSTMCell(LSTM_SIZE, forget_bias=1.0)
outputs, _ = tf.nn.static_rnn(lstm_cell, x, dtype=tf.float32)
outputs = outputs[-1]
print('-->OUTPUTS', outputs)
weight = tf.Variable(tf.random_normal([LSTM_SIZE, N_OUTPUTS]))
bias = tf.Variable(tf.random_normal([N_OUTPUTS]))
predictions = tf.matmul(outputs, weight) + bias
print('-->PREDICTIONS', predictions)
print('-->LABELS', labels)
loss = tf.losses.mean_squared_error(labels, predictions)
print('-->LOSS', loss)
train_op = tf.contrib.layers.optimize_loss(loss=loss, global_step=tf.train.get_global_step(), learning_rate=0.01, optimizer="SGD")
eval_metric_ops = {"rmse": tf.metrics.root_mean_squared_error(labels, predictions)}