Конвертировать model.fit_generator в model.fit - PullRequest
0 голосов
/ 08 марта 2020

У меня есть следующие коды,

train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

Теперь model.fit_generator определяется следующим образом:

model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800)

Теперь model.fit_generator устарело, как правильно изменить model.fit_generator на model.fit в этом случае?

Ответы [ 2 ]

2 голосов
/ 08 марта 2020

Вам просто нужно изменить model.fit_generator() на model.fit().

Начиная с TensorFlow 2.1, model.fit() также принимает генераторы в качестве входных данных. Так просто.

Из официальной документации TensorFlow:

Предупреждение: ЭТА ФУНКЦИЯ УСТАРЕЛА. Он будет удален в следующей версии. Инструкции по обновлению: Пожалуйста, используйте Model.fit, который поддерживает генераторы.

0 голосов
/ 03 апреля 2020

старое обучение = model.fit_generator (генератор = train_generator, steps_per_epoch = 2048 // 36, эпох = 10, validation_data = validation_generator, validation_steps = 832 // 16)

избавиться от 'генератора =' нового training = model.fit (train_generator, steps_per_epoch = 2048 // 128, эпох = 10, validation_data = validation_generator, validation_steps = 832 // 16)

...