Генераторы дополнения данных не работают с TensorFlow 2.0 - PullRequest
0 голосов
/ 20 февраля 2020

Я пытаюсь обучить модель с генераторами увеличения данных изображения на TensorFlow 2.0, после загрузки набора данных Kaggle cats_vs_dogs, используя следующий код.

train_datagen = ImageDataGenerator(rescale=1. / 255,
                                   rotation_range=40,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(train_dir,  
                                                    target_size=(150, 150), 
                                                    batch_size=32,
                                                    class_mode='binary')

validation_generator = test_datagen.flow_from_directory(validation_dir,  
                                                    target_size=(150, 150), 
                                                    batch_size=32,
                                                    class_mode='binary')

history = model.fit_generator(train_generator,
                              steps_per_epoch=100,
                              epochs=100,
                              validation_data=validation_generator,
                              validation_steps=50)

Но в первую эпоху получаю эту ошибку:

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
WARNING:tensorflow:From <ipython-input-18-e571f2719e1b>:27: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.
WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']
WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']
Train for 100 steps, validate for 50 steps
Epoch 1/100
 63/100 [=================>............] - ETA: 59s - loss: 0.7000 - accuracy: 0.5000 WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 10000 batches). You may need to use the repeat() function when building your dataset.

Как мне изменить вышеуказанную базу кода для TensorFlow 2?

1 Ответ

0 голосов
/ 20 февраля 2020

Набор данных kaggle содержит 25000 примеров обучения. Сообщение об ошибке гласит, что:

Тензор потока цитат: в вашем вводе закончились данные; прерывающее обучение. Убедитесь, что ваш набор данных или генератор может генерировать не менее steps_per_epoch * epochs пакетов (в данном случае 10000 пакетов). Вам может понадобиться использовать функцию repeat () при построении набора данных.

Это означает, что генератору данных необходимо сгенерировать как минимум 10000 пакетов. Но при текущем размере партии 32 генератор будет производить только 25000/32, что примерно равно 781 партии. Мое предложение - попытаться уменьшить steps_per_epoch или epochs и попробовать.

Вы можете избавиться от сообщения об устаревании, передав объект генератора в model.fit(...) вместо model.fit_generator

...