получить номер текущей эпохи в функции потерь - PullRequest
0 голосов
/ 02 апреля 2019

Я стремлюсь получить текущий номер эпохи в пользовательской функции потерь.Это может быть полезно для i) использования различных методов функции потерь в зависимости от эпохи или ii) использования некоторого зависимого от эпохи параметра alpha в расчете потерь.Например:

current_epoch = K.variable(0.)

def custom_loss(y_true, y_pred):
   # i) use different loss function based on epoch
   c_epoch = K.get_value(current_epoch)
   if c_epoch < t_change:
       # compute loss_1
   else:
       # compute loss_2

   # ii) also useful to do: result = (1 - alpha)*loss_1 + alpha*loss_2
   return result

Без успеха я попытался использовать обратные вызовы Lambda, чтобы получить текущую эпоху и передать ее функции потерь.В частности, я попробовал подходы, предложенные здесь и здесь .По сути, он состоит из:

from keras.callbacks import LambdaCallback

def get_epoch(epoch):
    K.set_value(current_epoch, epoch)
    # or K.set_value(alpha, functionOf(epoch))

epochchanger = LambdaCallback(on_epoch_begin=get_epoch)

Наконец, скомпилируйте и установите:

model.compile(loss=custom_loss, metrics=...)
model.fit_generator(..., callbacks=[epochchanger])

Var current_epoch корректно обновляется каждую эпоху в get_epoch().Но обновленный current_epoch не достигает функции custom_loss.Вместо этого current_epoch в функции custom_loss остается со значением инициализации навсегда.

Есть какие-либо предложения о том, как получить обновленный current_epoch в функции потерь?Примечание: перекомпиляция модели была бы способом изменить функцию потерь, но я стараюсь этого не делать, поскольку было бы неоптимальным изменить состояние оптимизатора.Спасибо!

...