Я написал специальную функцию потерь, которая больше всего наказывает за неправильный знак прогноза, квадратную ошибку для больших и абс потерь для небольших различий:
class CustomLoss(keras.losses.Loss):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def call(self, y_true, y_pred):
error = y_true - y_pred
wrong_direction = y_true * y_pred < 0
small_diff = tf.abs(error) < 0.005
large_loss = 10 * tf.square(error)
square_loss = tf.square(error)
linear_loss = tf.abs(error)
return tf.where(wrong_direction, large_loss, tf.where(small_diff, linear_loss, square_loss))
def get_config(self):
base_config = super().get_config()
return {**base_config}
Это работает нормально, но обучение очень медленный. Я подумал, что это может быть из-за того, что я не использовал tf.cond (), поэтому попытался реализовать его таким образом, но он не работает:
linear_loss = tf.cond(tf.abs(error) < 0.005, tf.abs(error), tf.square(error))
return tf.where(wrong_direction, large_loss, linear_loss)
Есть идеи, как его правильно реализовать или есть это повлияет на время тренировки?