Контроль эпох и партий в Керасе - PullRequest
0 голосов
/ 19 ноября 2018

Я хотел бы реализовать модель автоэнкодера, которая действует следующим образом:

for epoch in xrange(100):
  for X_batch in batch_list:
     model.train_on_batch(X_batch, X_batch)
     training_error = model.evaluate(X_batch, X_batch, verbose=0)
  average the training error by the number of the batches considered
  save it as the epoch training error
  call the function to get the validation error in the same fashion over  the validation data
  compare the two errors and decide whether go on training or stopping

Я посмотрел в Интернете и уже что-то спросил, и мне предложили использовать fit_generator, но я не понимаю, как это реализовать. Или я должен использовать метод train_on_batch или метод подгонки с числом эпох, равным 1, для точного подбора модели?

Какая лучшая практика в этом случае? У вас есть пример или похожий вопрос, чтобы связать меня?

1 Ответ

0 голосов
/ 19 ноября 2018

Насколько я понимаю, вы хотите использовать ошибку проверки в качестве критерия ранней остановки.Хорошая новость заключается в том, что у keras уже есть ранняя остановка обратного вызова.Так что все, что вам нужно, это создать обратный вызов и вызвать его во время обучения после нескольких эпох / итераций.

keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)

Давайте посмотрим на train_on_batch и fit ()

train_on_batch(x, y, sample_weight=None, class_weight=None)


fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

Вы можете увидетьэтот train_on_batch не принимает никакого обратного вызова в качестве входных данных, поэтому хорошим выбором будет использовать здесь фитинг, если только вы не хотите реализовать его самостоятельно.

Теперь вы можете вызывать фитинг следующим образом

callbacks = [EarlyStopping(monitor='val_loss', patience=2),
         ModelCheckpoint(filepath='path to latest ckpt', monitor='val_loss', save_best_only=True)]

history = model.fit(train_features,train_target, epochs=num_epochs, callbacks=callbacks, verbose=0, batch_size=your_choice, validation_data) 
...