Прогнозирование пакетов с использованием API данных Tensorflow и модели Keras - PullRequest
0 голосов
/ 17 октября 2019

Предположим, у меня есть набор данных и модель Keras. Набор данных был разделен на партии с использованием batch() в API набора данных. Теперь я ищу эффективный и чистый способ сделать пакетные прогнозы для всех тестируемых образцов.

Я попробовал следующий код, и он работает.

batch_size = 32
dataset = dataset.batch(batch_size)
predictions = keras_model.predict(dataset, steps=math.ceil(num_testing_samples / batch_size))

Интересно, есть ли более эффективныйи элегантный подход для реализации этого?

1 Ответ

0 голосов
/ 17 октября 2019

TF> = 1.14.0

Вы можете просто установить steps=None. Из официальной документации tf.keras.Model.predict():

Если x является набором данных tf.data, а значение шагов равно None, прогнозирование будет выполняться, пока не будет исчерпан входной набор данных.

Просто убедитесь, что ваш dataset объект не находится в режиме повтора, и вы готовы к работе :).

TF 1.12.0 и 1.13.0

Поддержка tf.data.Dataset с tf.keras очень плохая в этих версиях. Объект tf.data.Dataset преобразуется в итератор здесь , который затем вызывает ошибку здесь , если вы не установили аргумент steps. Это исправлено в 1.14.0.

...