Как Робби предлагает в другом ответе , похоже, что ваша старая реализация использовала фиксированные размеры пакетов повсюду (предположительно, с использованием API, подобного tf.train.batch()
или одной из его оболочек с Аргумент по умолчанию allow_smaller_final_batch=False
), а поведение пакетирования по умолчанию в tf.data
(через tf.data.Dataset.batch()
и tf.contrib.data.map_and_batch()
) состоит в том, чтобы включить меньшую конечную партию.
Ошибка наиболее вероятна в model_fn
. Без этой функции трудно угадать, но я подозреваю, что существует либо явное (и неверное) утверждение формы тензора с помощью Tensor.set_shape()
(возможно, в коде библиотеки), либо ошибка в реализация tf.losses.sparse_softmax_cross_entropy()
.
Во-первых, я предполагаю, что тензоры features
и labels
, возвращаемые из input_fn()
, имеют статически неизвестный размер партии. Можете ли вы подтвердить это, напечатав объекты features
и labels
и убедившись в том, что их сообщенные свойства Tensor.shape
имеют None
для 0-го измерения?
Далее найдите вызов на tf.losses.sparse_softmax_cross_entropy()
в вашем model_fn
. Распечатайте объект, который передан в качестве аргумента weights
этой функции, который должен быть tf.Tensor
, и найдите его статическую форму. Учитывая ошибку, которую вы видите, я подозреваю, что она будет иметь форму, такую как (110,)
, где 110
- указанный вами размер пакета. Если это так, в model_fn
есть ошибка, которая неверно утверждает, что форма весов является полной партией, а может и не быть. (Если это не так, то в tf.losses.sparse_softmax_cross_entropy()
есть ошибка! Пожалуйста, откройте проблему GitHub с примером, который позволяет нам воспроизвести проблему.)
В сторону: Почему это объясняет ошибку? Код , который вызывает сбой tf.where()
op, выглядит следующим образом (отредактировано для удобства чтения):
num_present = tf.where(tf.equal(weights, 0.0), # This input is shape [74]
tf.zeros_like(weights), # This input is shape [110]
tf.ones_like(weights) # This input is probably [110]
)
Этот вариант tf.where()
op (названный "Select"
в сообщении об ошибке по историческим причинам) требует, чтобы все три входа имели одинаковый размер. Внешне tf.equal(weights, 0.0)
, tf.ones_like(weights)
и tf.zeros_like(weights)
имеют одинаковую форму, которая является формой weights
. Однако если статическая форма (результат Tensor.shape
) отличается от динамической формы , то поведение не определено.
Что на самом деле происходит? В данном конкретном случае, скажем, статическая форма weights
равна [110]
, но динамическая форма равна [74]
. Статическая форма наших трех аргументов tf.where()
будет [110]
. Реализация tf.equal()
не заботится о несоответствии, поэтому ее динамическая форма будет [74]
. Реализации tf.zeros_like()
и tf.ones_like()
используют оптимизацию , которая игнорирует эту динамическую форму, когда статическая форма полностью определена, и поэтому их динамические формы будут [110]
, вызывая ошибку, которую вы видите.
Правильное исправление заключается в том, чтобы найти код, который устанавливает фиксированный размер пакета в вашем model_fn
, и удалить его. Логика оптимизации и оценки в TensorFlow устойчива к переменным размерам партий, и это гарантирует, что все ваши данные будут использованы в процессах обучения и оценки.
Менее желательным краткосрочным решением было бы удаление небольшой партии в конце данных. Здесь есть несколько вариантов: