Как создать параметр функции потерь, который зависит от числа эпох в Керасе? - PullRequest
0 голосов
/ 09 октября 2018

У меня есть пользовательская функция потерь с гиперпараметром alpha, которую я хочу менять каждые 20 эпох в течение тренировки.Функция потерь выглядит примерно так:

def custom_loss(x, x_pred): 
    loss1 = binary_crossentropy(x, x_pred)
    loss2 = (x, x_pred)
    return (alpha)* loss1 + (1-alpha)*loss2

Из моего исследования, создание собственного обратного вызова - путь.Я посмотрел на решение для аналогичного вопроса здесь и здесь , но решения не реализуют решение обратного вызова, чего я и хочу достичь.

Я попытался создать пользовательский обратный вызов, изменив обратный вызов LearningRateScheduler из keras repo

class changeAlpha(Callback):
    def __init__(self, alpha):
        super(changeAlpha, self).__init__()
        self.alpha = alpha 

    def on_epoch_begin(self, epoch, logs={}):
        if epoch%20 == 0:   
             K.set_value(self.alpha, K.get_value(self.alpha) * epoch**0.95)
             print("Setting alpha to =", str(alpha))

Однако я не уверен, что значение альфафактически соответствует альфа-значению в моей функции потерь.В любом случае, когда я помещаю обратный вызов changeAlpha в метод model.fit, я получаю attribute error.

Может ли кто-нибудь помочь мне отредактировать обратный вызов так, чтобы он изменил мой параметр alpha послеопределенное количество эпох?

1 Ответ

0 голосов
/ 09 октября 2018

Я понял вашу идею.Я думаю, что проблема в том, что альфа в функции потерь не относится к члену класса changeAlpha.Вы можете попробовать так:

instance = changeAlpha()
def custom_loss(x, x_pred): 
    loss1 = binary_crossentropy(x, x_pred)
    loss2 = (x, x_pred)
    return (instance.alpha*)* loss1 + (1-instance.alpha)*loss2

Или вы можете сделать альфу переменной класса, а не переменную установки, а затем изменить функцию потерь, как показано ниже:

def custom_loss(x, x_pred): 
    loss1 = binary_crossentropy(x, x_pred)
    loss2 = (x, x_pred)
    return (changeAlpha.alpha*)* loss1 + (1-changeAlpha.alpha)*loss2

Надеждаэто может вам помочь.

...