Неизвестное количество шагов - Обучение сверточной нейронной сети в Google Colab Pro - PullRequest
2 голосов
/ 29 апреля 2020

Я пытаюсь запустить (обучить) мой CNN на Google Colab Pro, когда я запускаю свой код, все в порядке, но он не знает количество шагов, поэтому создается бесконечный l oop.

Mounted at /content/drive
2.2.0-rc3
Found 10018 images belonging to 2 classes.
Found 1336 images belonging to 2 classes.
WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.
Epoch 1/300
      8/Unknown - 364s 45s/step - loss: 54.9278 - accuracy: 0.5410

Я использую ImageDataGenerator() для загрузки изображений. Как я могу это исправить?

1 Ответ

1 голос
/ 29 апреля 2020

Итератор ничего не хранит, он генерирует данные динамически. Когда вы используете итератор набора данных или набора данных, вы должны предоставить steps_per_epoch. Длина итератора неизвестна, пока вы не выполните итерацию. Вы можете явно передать len(datafiles) в функцию .fit. Итак, вам необходимо предоставить steps_per_epoch, как показано ниже.

model.fit_generator(
    train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size
)

Более подробно упоминается здесь

steps_per_epoch: Integer или None. Общее количество шагов (партий образцов) до объявления одной эпохи законченной и начала следующей эпохи. При обучении с использованием входных тензоров, таких как тензоры данных TensorFlow, значение по умолчанию None равно количеству выборок в вашем наборе данных, деленному на размер пакета, или 1, если это невозможно определить. Если x - это набор данных tf.data, а 'steps_per_epoch' - None, эпоха будет работать до тех пор, пока не будет исчерпан входной набор данных. Этот аргумент не поддерживается для входных данных массива.

Я заметил, что вы используете двоичную классификацию. Еще одна вещь, которую нужно помнить, когда вы используете ImageDataGenerator, это предоставить class_mode, как показано ниже. В противном случае будет ошибка (в кератах) или 50% точность (в tf.keras).

train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),class_mode='binary') #
...