Мне кажется, что простым решением было бы просто унаследовать от CosineAnnealingWarmRestarts
, а затем изменить его параметры self.optimizer
внутри переопределенной функции step
. В псевдокоде это будет что-то вроде
class myScheduler(torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):
def __init__(self,
optimizer,
T_0,
T_mult=1,
eta_min=0,
last_epoch=-1):
#initialize base calss
super().__init__(.... blablabla ...)
def step(self):
#call step() from base class
#Do some book-keeping to determine if you've hit a restart
#now change optimizer lr for each parameter group
if some_condition:#condition like number of iterations, restarts, etc
self.optimizer.param_groups[i]['lr']*=some_coef