PyTorch: CrossEntropyLoss, изменение веса класса не меняет вычисленные потери - PullRequest
0 голосов
/ 19 января 2019

Согласно Док для кросс-энтропийной потери, взвешенная потеря рассчитывается путем умножения веса для каждого класса и первоначальной потери.

Однако в реализации pytorch вес класса, по-видимому, не влияет на окончательное значение потерь, если оно не установлено равным нулю. Ниже приведен код:

from torch import nn
import torch

logits = torch.FloatTensor([
    [0.1, 0.9],
])
label = torch.LongTensor([0])

criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711

# Change class weight for the first class to 0.1
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711, should be 0.11711

# Change weight for first class to 0
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 0

Как показано в коде, вес класса, кажется, не имеет никакого эффекта, если он не установлен в 0, это поведение противоречит документации.

Обновление
Я реализовал версию взвешенной перекрестной энтропии, которая, на мой взгляд, является «правильным» способом сделать это.

import torch
from torch import nn

def weighted_cross_entropy(logits, label, weight=None):
    assert len(logits.size()) == 2
    batch_size, label_num = logits.size()
    assert (batch_size == label.size(0))

    if weight is None:
        weight = torch.ones(label_num).float()

    assert (label_num == weight.size(0))

    x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze()
    log_terms = torch.log(torch.sum(torch.exp(logits), dim=1))

    weights = torch.gather(weight, 0, label).float()

    return torch.mean((x_terms+log_terms)*weights)

logits = torch.FloatTensor([
    [0.1, 0.9],
    [0.0, 0.1],

])

label = torch.LongTensor([0, 1])

neg_weight = 0.1

weight = torch.FloatTensor([neg_weight, 1])

criterion = nn.CrossEntropyLoss(weight=weight)
loss = criterion(logits, label)

print(loss.item()) # results: 0.69227
print(weighted_cross_entropy(logits, label, weight).item()) # results: 0.38075

Что я сделал, так это умножил каждый экземпляр в пакете на соответствующий вес класса. Результат все еще отличается от оригинальной реализации pytorch, что заставляет меня задуматься, как Pytorch на самом деле это реализует.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...