Что делать, если steps_per_epoch не вписывается в число образцов? - PullRequest
0 голосов
/ 01 июня 2018

с использованием Keras fit_generator, steps_per_epoch должно быть эквивалентно общему количеству доступных выборок, деленному на batch_size.

Но как отреагирует генератор или fit_generator, если я выберу batch_size, который не подходит n раз к выборкам?Дает ли он выборки до тех пор, пока не сможет больше заполнить batch_size или просто использует меньший batch_size для последнего выхода?

Почему я спрашиваю: я делю свои данные на поезда / проверки / испытания разного размера (разные%), но использую один и тот же размер партии для наборов поездов и проверок, но особенно для наборов поездов и испытаний.Поскольку они различаются по размеру, я не могу гарантировать, что размер партии вписывается в общее количество образцов.

Ответы [ 2 ]

0 голосов
/ 01 июня 2018

Если вы назначите N параметру steps_per_epoch из fit_generator(), Keras в основном будет вызывать ваш генератор N раз, прежде чем считать, что одна эпоха завершена.Генератор может выдавать все ваши сэмплы в N пакетах.

Обратите внимание, что, поскольку для большинства моделей хорошо иметь разные размеры пакетов на каждой итерации, вы можете исправить steps_per_epoch = ceil(dataset_size / batch_size) и позволить вашему генератору выводить данныеменьшая партия для последних образцов.

0 голосов
/ 01 июня 2018

Если это ваш генератор с yield

Это вы создаете генератор, поэтому поведение определяется вами.

Если steps_per_epoch больше ожидаемых партий, подгонитеничего не увидит, просто будет продолжать запрашивать пакеты, пока не достигнет количества шагов.

Единственное: вы должны убедиться, что ваш генератор бесконечен.

Сделайте это, например, с while True: в начале.

Если это генератор из ImageDataGenerator.

Если генератор из ImageDataGenerator, это на самом деле keras.utils.Sequence и он имеет свойство length: len(generatorInstance).

Тогда вы можете сами проверить, что происходит:

remainingSamples = total_samples % batch_size #confirm that this is gerater than 0
wholeBatches = total_samples // batch_size
totalBatches = wholeBatches + 1

if len(generator) == wholeBatches:
    print("missing the last batch")    
elif len(generator) == totalBatches:
    print("last batch included")
else:
    print('weird behavior')

и проверить размер последней партии:

lastBatch = generator[len(generator)-1]

if lastBatch.shape[0] == remainingSamples:
    print('last batch contains the remaining samples')
else:
    print('last batch is different')
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...