Что касается взвешивания на ANN с 3 выходами с пользовательскими потерями - PullRequest
0 голосов
/ 11 февраля 2019

Я пытаюсь определить пользовательскую функцию потерь, которая принимает 3 выходные переменные в регрессионной модели.

def custom_loss(y_true, y_pred):
    y_true_c = K.cast(y_true, 'float32')  # Shape=(batch_size, 3)
    y_pred_c = K.cast(y_pred, 'float32')  # Shape=(batch_size, 3)

    # Compute error
    num = K.abs(y_true_c - y_pred_c)  # Shape=(batch_size, 3)
    den = K.maximum(y_true_c, y_pred_c)   # Shape=(batch_size, 3)
    err = K.sum(num / den, axis=-1)  # Shape=(batch_size,)

    # Output loss
    return K.mean(err)

Как мне взвесить 3 потери, заданные 3 выходами, до их суммирования взначение единственного убытка?

Мой оператор model.compile () в настоящее время:

model.compile(loss=custom_loss, metrics=['mse'],optimizer=optimizer, loss_weights=[0.25,0.5,0.25])

, где я пытаюсь взвесить их 0,25, 0,5, 0,25 (суммы до 1) для каждогоиз 3 выходов соответственно.Однако я думаю, что этот инструмент может не работать с пользовательской функцией потерь.

Как мне этого добиться?

1 Ответ

0 голосов
/ 11 февраля 2019

Вы можете передать дополнительный аргумент weights своему пользовательскому проигрышу следующим образом:

def custom_loss(weights):
    def loss(y_true, y_pred):
        y_true_c = K.cast(y_true, 'float32')  # Shape=(batch_size, 3)
        y_pred_c = K.cast(y_pred, 'float32')  # Shape=(batch_size, 3)

        # Compute error
        num = K.abs(y_true_c - y_pred_c)  # Shape=(batch_size, 3)
        den = K.maximum(y_true_c, y_pred_c)  # Shape=(batch_size, 3)
        aux = weights * (num / den)  # Shape=(batch_size, 3)
        err = K.sum(aux, axis=-1)  # Shape=(batch_size,)

        # Output loss
        return K.mean(err)

    return loss

И затем скомпилируйте свою модель, как показано ниже:

# weights shape is (3,)
weights = np.array([0.25, 0.5, 0.25])
model.compile(loss=custom_loss(weights), metrics=['mse'], optimizer=optimizer)

ПРИМЕЧАНИЕ: Не проверено.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...