Остановка и возобновление тренировки на VGG-16 - PullRequest
0 голосов
/ 24 августа 2018

Я использую предварительно обученную модель VGG-16 для классификации изображений. Я добавляю пользовательский последний слой, так как количество моих классов классификации равно 10. Я тренирую модель в течение 200 эпох.

Мой вопрос: есть ли способ, если я случайно остановлю (закрыв окно Python) обучение в какой-то эпохе, скажем, в эпоху нет. 50 и резюме оттуда? Я читал о сохранении и перезагрузке модели, но, насколько я понимаю, она работает только для наших пользовательских моделей, а не для предварительно обученных моделей, таких как VGG-16.

Ответы [ 2 ]

0 голосов
/ 24 августа 2018

Вот настроенная версия ModelCheckpoint , которую я использую для возобновления обучения с определенной эпохи, gist .Он сохранит эпоху и другие журналы в соответствующий файл JSON, а также проверит, следует ли возобновить обучение или нет при запуске.Вам нужно позвонить get_last_epoch и установить initial_epoch в model.fit, чтобы возобновить эту эпоху.

import json

class StatefulCheckpoint(ModelCheckpoint):
  """Save extra checkpoint data to resume training."""
  def __init__(self, weight_file, state_file=None, **kwargs):
    """Save the state (epoch etc.) along side weights."""
    super().__init__(weight_file, **kwargs)
    self.state_f = state_file
    self.state = dict()
    if self.state_f:
      # Load the last state if any
      try:
        with open(self.state_f, 'r') as f:
          self.state = json.load(f)
        self.best = self.state['best']
      except Exception as e: # pylint: disable=broad-except
        print("Skipping last state:", e)

  def on_train_begin(self, logs=None):
    prefix = "Resuming" if self.state else "Starting"
    print("{} training...".format(prefix))

  def on_epoch_end(self, epoch, logs=None):
    """Saves training state as well as weights."""
    super().on_epoch_end(epoch, logs)
    if self.state_f:
      state = {'epoch': epoch+1, 'best': self.best}
      state.update(logs)
      state.update(self.params)
      with open(self.state_f, 'w') as f:
        json.dump(state, f)

  def get_last_epoch(self, initial_epoch=0):
    """Return last saved epoch if any, or return default argument."""
    return self.state.get('epoch', initial_epoch)
0 голосов
/ 24 августа 2018

Вы можете использовать обратный вызов ModelCheckpoint для регулярного сохранения модели. Чтобы использовать его, передайте параметр callbacks методу fit:

from keras.callbacks import ModelCheckpoint
checkpointer = ModelCheckpoint(filepath='model-{epoch:02d}.hdf5', ...)
model.fit(..., callbacks=[checkpointer])

Затем, позже, вы можете загрузить последнюю сохраненную модель. Для дополнительной настройки этого обратного вызова обратитесь к документации.

...