Что-нибудь вроде , это сработает. Я не тестировал это и не пытался доводить до совершенства ... но детали есть, так что вы можете заставить его работать так, как вам нравится.
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
import numpy as np
class LRSetter(Callback):
def __init__(self, start_lr=0, middle_lr=0.001, end_lr=0.00001,
start_mid_batches=200, end_epochs=2000):
self.start_mid_lr = np.linspace(start_lr, middle_lr, start_mid_batches)
#Not exactly right since you'll have gone through a couple epochs
#but you get the picture
self.mid_end_lr = np.linspace(middle_lr, end_lr, end_epochs)
self.start_mid_batches = start_mid_batches
self.epoch_takeover = False
def on_train_batch_begin(self, batch, logs=None):
if batch < self.start_mid_batches:
tf.keras.backend.set_value(self.model.optimizer.lr, self.start_mid_lr[batch])
else:
self.epoch_takeover = True
def on_epoch_begin(self, epoch):
if self.epoch_takeover:
tf.keras.backend.set_value(self.model.optimizer.lr, self.mid_end_lr[epoch])