Я использую модель генерации текста (RNN) в Tensorfow 2.0.0-alpha0, и даже при получении метрики потерь при подгонке модели при вставке точности я получаю следующую ошибку:
InvalidArgumentError: Несовместимые фигуры: [64] против [64 200]
[[{{node metrics_4 / precision / Equal}}]]
[Op: __ inference_keras_scratch_graph_6491]
Я пытался вручную определить точность для одной партии (предварительное обучение):
def loss(labels, logits):
return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
def accuracy(labels, logits):
return tf.keras.metrics.sparse_categorical_accuracy(labels,l ogits)
example_batch_loss = loss(target_example_batch, example_batch_predictions)
example_batch_acc = accuracy(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Loss: ", example_batch_loss.numpy().mean())
print("Accuracy: ", example_batch_acc.numpy().mean())
Вывод был:
Форма прогноза: (64, 200, 34) # (batch_size, sequence_length, vocab_size)
Потеря: 3,5263805
Точность: 0,01265625
Затем я последовал:
optimizer = tf.keras.optimizers.RMSprop(lr=lr)
model.compile(optimizer=optimizer, loss=loss, metrics =['accuracy'])
history = model.fit(dataset, epochs=epochs, callbacks[checkpoint_callback])
и получил ошибку, указанную выше (потеря работает нормально). Если я попробую «точность = точность» во время компиляции, я получу:
повысить ValueError ('Аргументы ключевого слова Session не поддерживаются во время
нетерпеливое исполнение. Вы прошли:% s '% (kwargs,))
Есть мысли / предложения?