Расходящиеся потери в Керасе с нестандартной потерей - PullRequest
3 голосов
/ 27 сентября 2019

У меня есть полностью подключенная сеть прямой связи, реализованная с помощью Keras.Первоначально я использовал двоичную кросс-энтропию в качестве потерь и метрики, а оптимизатор Адама следующим образом:

adam = keras.optimizers.Adam(lr=0.01, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
model.compile(optimizer=adam, loss='binary_crossentropy', metrics=['binary_crossentropy']) 

Эта модель хорошо работает и дает хорошие результаты.Чтобы получить лучшие результаты, я хочу использовать другую функцию потерь и метрику, как показано ниже:

import keras.backend as K

def soft_bit_error_loss(yTrue, yPred):
    loss = K.pow(1 - yPred, yTrue) * K.pow(yPred, 1-yTrue)
    return K.mean(loss)

def ber(yTrue, yPred):
    x_hat_train = K.cast(K.greater(yPred, 0.5), 'uint8')
    train_errors = K.cast(K.not_equal(K.cast(yTrue, 'uint8'), x_hat_train), 'float32')
    train_ber = K.mean(train_errors)
    return train_ber

Я использую ее для компиляции моей модели, как показано ниже

model.compile(optimizer=adam, loss=soft_bit_error_loss, metrics=[ber])

Однако, когда ясделать это, потери и метрика расходятся после некоторой тренировки, каждый раз как на следующих рисунках.

custom loss

ber

Что может быть причиной этого?

1 Ответ

1 голос
/ 27 сентября 2019

Ваша функция потерь очень нестабильна, посмотрите на нее:

enter image description here

Где я заменил y_pred (переменная) на x и y_true (постоянная) с c для простоты.

Когда ваши прогнозы приближаются к нулю, по крайней мере одна операция будет стремиться к 1/0, что является бесконечным.Хотя по теории пределов вы можете знать, что результат в порядке, Керас не знает «целую» функцию как единое целое, он вычисляет производные на основе используемых базовых операций.

Итак, одно простое решение - это то, на которое указывает @today:

loss = K.switch(yTrue == 1, 1 - yPred, yPred)

Это точно такая же функция (различие только в том случае, если c не равно нулю или 1).

Кроме того, еще проще, для c=0 или c=1, это просто loss='mae'.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...