ModelCheckpoint в керасе сравнивают со старой моделью - PullRequest
0 голосов
/ 25 февраля 2020

Я новичок в глубоком обучении и занимаюсь некоторыми проблемами классификации.

Я использую EarlyStopping и ModelCheckpoint в своем списке обратных вызовов, но когда начинается обучение, базовая линия контрольной точки модели является отрицательной бесконечностью и переписать 'best_model.h5'.

Однако, best_model.h5 уже хранит мою последнюю лучшую модель. Я хочу установить базовую линию ModelCheckpoint для производительности моей последней лучшей модели по данным.

Кто-нибудь может мне помочь?

es = EarlyStopping(monitor='val_accuracy', mode='max', verbose=1, patience=3)
mc = ModelCheckpoint('best_model.h5', monitor='val_accuracy', mode='max', save_best_only=True, verbose=1)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, validation_data=(x_valid, y_valid), batch_size=400,\
                  epochs=20, callbacks=[es, mc])

введите описание изображения здесь

Ответы [ 2 ]

0 голосов
/ 25 февраля 2020

Я думаю, что ваша проблема была в том, что вы хотели сохранить val_a cc до первой эпохи. Возвращаясь к механизму общей проблемы машинного обучения, я не думаю, что значение точности до первой итерации имеет смысл для сравнения. (ваша модель не была обучена по данному набору данных). Если вы хотите, вы можете проверить потерю проверки (val_loss), если это возможно.

Но если вы хотите сохранить журнал вашего тренировочного процесса, вам не нужно сохранять модель для каждой эпохи. Вы можете использовать функцию истории как (импортировать matplotlib.pyplot как plt)

results = model.fit(x_train, y_train, validation_data=(x_valid, y_valid), batch_size=400,epochs=20, callbacks=[es, mc])

plt.figure(figsize=(8, 8))
plt.title("Learning curve")
plt.plot(results.history["loss"], label="loss")
plt.plot(results.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.savefig('loss.png')

plt.figure(figsize=(8, 8))
plt.title("Learning curve")
plt.plot(results.history["acc"], label="accuracy")
plt.plot(results.history["val_acc"], label="accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.savefig('acc.png')
0 голосов
/ 25 февраля 2020

Сделайте это:

mc = ModelCheckpoint('best_model-{epoch:04d}_{val_accuracy:.2f}.h5', monitor='val_accuracy', mode='max', save_best_only=True, verbose=1)

Это сохранит вашу новую лучшую модель с номером epoch и validation_accuracy без перезаписи best_model.h5. Позже это поможет вам выбрать лучшие модели и сравнить.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...