Функция Custom Loss в TensorFlow со ступенчатыми потерями - PullRequest
0 голосов
/ 03 августа 2020

Я написал специальную функцию потерь, которая больше всего наказывает за неправильный знак прогноза, квадратную ошибку для больших и абс потерь для небольших различий:


class CustomLoss(keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, y_true, y_pred):
        error = y_true - y_pred
        wrong_direction = y_true * y_pred < 0
        small_diff = tf.abs(error) < 0.005

        large_loss = 10 * tf.square(error)
        square_loss = tf.square(error)
        linear_loss = tf.abs(error)

        return tf.where(wrong_direction, large_loss, tf.where(small_diff, linear_loss, square_loss))

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}

Это работает нормально, но обучение очень медленный. Я подумал, что это может быть из-за того, что я не использовал tf.cond (), поэтому попытался реализовать его таким образом, но он не работает:

linear_loss = tf.cond(tf.abs(error) < 0.005, tf.abs(error), tf.square(error))

return tf.where(wrong_direction, large_loss, linear_loss)

Есть идеи, как его правильно реализовать или есть это повлияет на время тренировки?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...