Пользовательская функция потерь Pytorch (функция потерь с шумоподавлением) - PullRequest
0 голосов
/ 23 октября 2019

Я пытаюсь написать пользовательскую функцию потерь ( Noise Reduction Loss ) в PyTorch. Это очень похоже на кросс-энтропийную потерю с той разницей, что она дает некоторую уверенность в предсказанном ответе (наибольшая вероятность в предсказанной матрице), предполагая, что некоторые метки неверны в обучающих данных. Здесь pred представляет прогнозируемую [m * L] матрицу, где m - количество примеров, а L - количество меток, y_true - [m * 1] матрица фактических меток, а «ro» - гиперпараметр, определяющий влияние каждого из них. из двух используемых критериев.

def lossNR(pred, y_true, ro):
    outputs = torch.log(pred)   # compute the log of softmax values
    out1 = outputs.gather(1, y_true.view([-1,1])) # pick the values corresponding to the labels
    l1 = -((ro)* torch.mean(out1))
    l2 = -(1-ro) * torch.mean((torch.max(outputs,1)[0]))
    print("l1=", l1)
    print("l2 = ", l2)
    return (l1+l2)

Я пробовал функцию потерь на различных наборах данных, но она не работает ни на что. Пожалуйста, предоставьте предложения.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...