Возобновить обучение с оптимизатором Adam в Керасе - PullRequest
2 голосов
/ 05 февраля 2020

Мой вопрос довольно простой, но я не могу найти однозначного ответа онлайн (пока).

Я сохранил веса модели керас, обученной с помощью оптимизатора Адама, после определенного количества эпох тренировка с использованием:

callback = tf.keras.callbacks.ModelCheckpoint(filepath=path, save_weights_only=True)
model.fit(X,y,callbacks=[callback])

Когда я возобновлю тренировку после закрытия моего юпитера, могу ли я просто использовать:

model.load_weights(path)

, чтобы продолжить тренировку.

Поскольку Адам зависит от номер эпохи (например, в случае спада скорости обучения), я хотел бы знать, как проще всего возобновить обучение в тех же условиях, что и раньше.

После ответа ibarrond я написал небольшой пользовательский обратный вызов .

optim = tf.keras.optimizers.Adam()
model.compile(optimizer=optim, loss='categorical_crossentropy',metrics=['accuracy'])

weight_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1, save_best_only=False)

class optim_callback(tf.keras.callbacks.Callback):
    '''Custom callback to save optimiser state'''

          def on_epoch_end(self,epoch,logs=None):
                optim_state = tf.keras.optimizers.Adam.get_config(optim)
                with open(optim_state_pkl,'wb') as f_out:                  
                       pickle.dump(optim_state,f_out)

model.fit(X,y,callbacks=[weight_callback,optim_callback()])

Когда я возобновлю обучение:

model.load_weights(checkpoint_path)
with open(optim_state_pkl,'rb') as f_out:                  
                    optim_state = pickle.load(f_out)
tf.keras.optimizers.Adam.from_config(optim_state)

Я просто хотел бы проверить, правильно ли это. Еще раз большое спасибо !!

Приложение: При дальнейшем чтении стандартной реализации Кераса Адама и оригинальной статьи Адама , я считаю, что Адам по умолчанию не зависит по номеру эпохи, но только по номеру итерации. Поэтому в этом нет необходимости. Однако этот код может быть полезен всем, кто хочет отслеживать другие оптимизаторы.

1 Ответ

5 голосов
/ 05 февраля 2020

Для того, чтобы идеально зафиксировать состояние вашего оптимизатора, вы должны сохранить его конфигурацию, используя функцию get_config(). Эта функция возвращает словарь (содержащий опции) , который можно сериализовать и сохранить в файле, используя pickle.

Чтобы перезапустить процесс, просто d = pickle.load('my_saved_tfconf.txt') чтобы получить словарь с конфигурацией, а затем сгенерировать свой оптимизатор Адама, используя функцию from_config(d) оптимизатора Keras Adam .

...