Функция потери PyTorch, когда метка представляет собой список? - PullRequest
0 голосов
/ 14 марта 2019

[Помогите отправить сообщение другу] У меня есть модель, которая возвращает двоичную последовательность предсказаний длины k, например, [0, 0.2, 0.6, 0.4, 0.8], и у меня есть метки типа [0, 1, 1, 0, 0].Как я могу определить функцию потерь здесь?

1 Ответ

0 голосов
/ 14 марта 2019

Если это бинарная классификация и тензор предсказаний взят из сигмоидальной функции, тогда вы можете использовать torch.nn.BCELoss двоичную кросс-энтропийную потерю.Если вы не применяете сигмоид / софтмакс для тензора прогнозов, то вам лучше использовать torch.nn.BCEWithLogitsLoss

В PyTorch функция потерь называется критерием, и вы можете определить их как:

criterion = nn.BCELoss()
loss = criterion(prediction, target)
...