Несовместимые формы для точности при подгонке модели tenorflow2 - PullRequest
0 голосов
/ 28 апреля 2019

Я использую модель генерации текста (RNN) в Tensorfow 2.0.0-alpha0, и даже при получении метрики потерь при подгонке модели при вставке точности я получаю следующую ошибку:

InvalidArgumentError: Несовместимые фигуры: [64] против [64 200]
[[{{node metrics_4 / precision / Equal}}]] [Op: __ inference_keras_scratch_graph_6491]

Я пытался вручную определить точность для одной партии (предварительное обучение):

def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
def accuracy(labels, logits):
    return tf.keras.metrics.sparse_categorical_accuracy(labels,l ogits)

example_batch_loss  = loss(target_example_batch, example_batch_predictions)
example_batch_acc  = accuracy(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Loss:      ", example_batch_loss.numpy().mean())
print("Accuracy:      ", example_batch_acc.numpy().mean())

Вывод был:

Форма прогноза: (64, 200, 34) # (batch_size, sequence_length, vocab_size) Потеря: 3,5263805 Точность: 0,01265625

Затем я последовал:

optimizer = tf.keras.optimizers.RMSprop(lr=lr) 
model.compile(optimizer=optimizer, loss=loss, metrics =['accuracy']) 
history = model.fit(dataset, epochs=epochs, callbacks[checkpoint_callback]) 

и получил ошибку, указанную выше (потеря работает нормально). Если я попробую «точность = точность» во время компиляции, я получу:

повысить ValueError ('Аргументы ключевого слова Session не поддерживаются во время нетерпеливое исполнение. Вы прошли:% s '% (kwargs,))

Есть мысли / предложения?

1 Ответ

0 голосов
/ 28 апреля 2019

accuracy не является стандартным аргументом Model.fit - он будет принят под **kwargs, который затем будет передан в session.run в графическом режиме.Попробуйте metrics=[accuracy].

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...