Согласно Док для кросс-энтропийной потери, взвешенная потеря рассчитывается путем умножения веса для каждого класса и первоначальной потери.
Однако в реализации pytorch вес класса, по-видимому, не влияет на окончательное значение потерь, если оно не установлено равным нулю. Ниже приведен код:
from torch import nn
import torch
logits = torch.FloatTensor([
[0.1, 0.9],
])
label = torch.LongTensor([0])
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711
# Change class weight for the first class to 0.1
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711, should be 0.11711
# Change weight for first class to 0
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 0
Как показано в коде, вес класса, кажется, не имеет никакого эффекта, если он не установлен в 0, это поведение противоречит документации.
Обновление
Я реализовал версию взвешенной перекрестной энтропии, которая, на мой взгляд, является «правильным» способом сделать это.
import torch
from torch import nn
def weighted_cross_entropy(logits, label, weight=None):
assert len(logits.size()) == 2
batch_size, label_num = logits.size()
assert (batch_size == label.size(0))
if weight is None:
weight = torch.ones(label_num).float()
assert (label_num == weight.size(0))
x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze()
log_terms = torch.log(torch.sum(torch.exp(logits), dim=1))
weights = torch.gather(weight, 0, label).float()
return torch.mean((x_terms+log_terms)*weights)
logits = torch.FloatTensor([
[0.1, 0.9],
[0.0, 0.1],
])
label = torch.LongTensor([0, 1])
neg_weight = 0.1
weight = torch.FloatTensor([neg_weight, 1])
criterion = nn.CrossEntropyLoss(weight=weight)
loss = criterion(logits, label)
print(loss.item()) # results: 0.69227
print(weighted_cross_entropy(logits, label, weight).item()) # results: 0.38075
Что я сделал, так это умножил каждый экземпляр в пакете на соответствующий вес класса. Результат все еще отличается от оригинальной реализации pytorch, что заставляет меня задуматься, как Pytorch на самом деле это реализует.