Эпоха Кераса проходит дважды - PullRequest
0 голосов
/ 01 мая 2020

Когда я запускаю model.fit_generator, эпоха запускается дважды, в первый раз она поднимается только до 39/40, затем во второй раз - 40/40. Любая причина, почему это происходит?

Вот GIF, вы также можете увидеть epoch 1/2 на самом деле появляется в epoch 2/2 прогоне. Это происходит только тогда, когда я передаю validation_data=validation_generator

enter image description here

Обновление, вот код:

набор данных отсюда

https://tiny-imagenet.herokuapp.com/

Пакеты:

absl-py==0.9.0
astor==0.7.1
attrs==19.3.0
autopep8==1.4.4
backcall==0.1.0
bleach==3.1.4
brotlipy==0.7.0
certifi==2020.4.5.1
cffi==1.14.0
chardet==3.0.4
colorama==0.4.3
cryptography==2.8
cycler==0.10.0
decorator==4.4.2
defusedxml==0.6.0
entrypoints==0.3
future==0.18.2
gast==0.2.2
google-pasta==0.2.0
grpcio==1.23.0
h5py==2.10.0
idna==2.9
imageio==2.8.0
importlib-metadata==1.6.0
ipykernel==5.2.0
ipython==7.13.0
ipython-genutils==0.2.0
jedi==0.17.0
Jinja2==2.11.2
joblib==0.14.1
json5==0.9.0
jsonschema==3.2.0
jupyter-client==6.1.3
jupyter-core==4.6.3
jupyter-tensorboard==0.2.0
jupyterlab==2.1.0
jupyterlab-server==1.1.1
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
kiwisolver==1.2.0
llvmlite==0.31.0
Markdown==3.2.1
MarkupSafe==1.1.1
matplotlib==3.2.1
mistune==0.8.4
nbconvert==5.6.1
nbformat==5.0.6
notebook==6.0.3
numba==0.48.0
numpy==1.18.1
olefile==0.46
opt-einsum==0+untagged.56.g2664021.dirty
pandas==1.0.3
pandocfilters==1.4.2
parso==0.7.0
pickleshare==0.7.5
Pillow==7.1.1
prometheus-client==0.7.1
prompt-toolkit==3.0.5
protobuf==3.11.4
pycparser==2.20
Pygments==2.6.1
pyOpenSSL==19.1.0
pyparsing==2.4.7
PyQt5==5.12.3
PyQt5-sip==4.19.18
PyQtWebEngine==5.12.1
pyreadline==2.1
pyrsistent==0.16.0
PySocks==1.7.1
python-dateutil==2.8.1
pytz==2019.3
pywin32==227
pywinpty==0.5.7
pyzmq==19.0.0
requests==2.23.0
scikit-learn==0.22.2.post1
scipy==1.2.1
Send2Trash==1.5.0
six==1.14.0
tensorboard==1.15.0
tensorflow==1.15.0
tensorflow-estimator==1.15.1
termcolor==1.1.0
terminado==0.8.3
testpath==0.4.4
tornado==6.0.4
traitlets==4.3.3
urllib3==1.25.9
wcwidth==0.1.9
webencodings==0.5.1
Werkzeug==0.16.1
win-inet-pton==1.1.0
wincertstore==0.2
wrapt==1.12.1
zipp==3.1.0

Код

train_datagen = ImageDataGenerator(validation_split=0.9)

train_generator = train_datagen.flow_from_directory(directory= 'tiny-imagenet-200/train/', 
                                                    target_size=(64, 64), 
                                                    batch_size=256, 
                                                    class_mode='categorical', 
                                                    shuffle=True, 
                                                    seed=42,
                                                    subset ="training"
                                                   )

val_data = pd.read_csv('./tiny-imagenet-200/val/val_annotations.txt', sep='\t', header=None, names=['File', 'Class', 'X', 'Y', 'H', 'W'])
val_data.drop(['X', 'Y', 'H', 'W'], axis=1, inplace=True)

valid_datagen  = ImageDataGenerator(validation_split=0.9)

validation_generator = valid_datagen.flow_from_dataframe(dataframe=val_data, 
                                                         directory='./tiny-imagenet-200/val/images/', 
                                                         x_col='File', 
                                                         y_col='Class', 
                                                         target_size=(64, 64),
                                                         color_mode='rgb', 
                                                         class_mode='categorical', 
                                                         batch_size=256, 
                                                         shuffle=True, 
                                                         seed=42,
                                                        subset ="training")

history = model.fit_generator(train_generator, 
                    epochs=2, 
                    validation_data=validation_generator, 
                    #callbacks=[tensorboard_callback]
                             )

1 Ответ

1 голос
/ 01 мая 2020

Вы используете validation_split при создании экземпляра ImageDataGenerator и установке subset ="training" на validation_generator, но на самом деле ваши наборы проверки и обучения разделены в разных каталогах. Я не уверен на 100%, но думаю, что это может быть связано с этим.

Кроме того, при вызове flow_from_dataframe: x_col я бы использовал одни и те же общие аргументы для обучения и проверки. , y_col, target_size, color_mode, et c.

Посмотрите на примеры, показанные здесь (официальные документы):

train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800) ```
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...