сохранить веса моделей в конце каждой N эпох - PullRequest
0 голосов
/ 05 июля 2018

Я тренирую NN и хотел бы сохранить весовые коэффициенты модели каждые N эпох для фазы прогнозирования. Я предлагаю этот черновик кода, он вдохновлен ответом @grovina здесь . Не могли бы вы сделать предложения? Заранее спасибо.

from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, model, N):
        self.model = model
        self.N = N
        self.epoch = 0

    def on_batch_end(self, epoch, logs={}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1

Затем добавьте это к вызову подгонки: чтобы сохранять веса каждые 5 эпох:

model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])

Ответы [ 2 ]

0 голосов
/ 05 июля 2018

Вы должны реализовать на on_epoch_end, а не on_batch_end. А также передача модели в качестве аргумента для __init__ является избыточной.

from keras.callbacks import Callback
class WeightsSaver(Callback):
  def __init__(self, N):
    self.N = N
    self.epoch = 0

  def on_epoch_end(self, epoch, logs={}):
    if self.epoch % self.N == 0:
      name = 'weights%08d.h5' % self.epoch
      self.model.save_weights(name)
    self.epoch += 1
0 голосов
/ 05 июля 2018

Вам не нужно передавать модель для обратного вызова. У него уже есть доступ к модели через супер. Так что удалите __init__(..., model, ...) аргумент и self.model = model. Вы должны иметь возможность доступа к текущей модели через self.model независимо. Вы также сохраняете его на каждом конце партии, а это не то, что вам нужно, вероятно, вы хотите, чтобы оно было on_epoch_end.

Но в любом случае то, что вы делаете, может быть сделано с помощью наивного modelcheckpoint callback . Вам не нужно писать собственный. Вы можете использовать это следующим образом:

mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5', 
                                     save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...