Параметр weight
используется для вычисления взвешенного результата для всех входных данных на основе их целевого класса. Если у вас есть только один вход или все входы одного и того же целевого класса, weight
не повлияет на потери.
См. Разницу между двумя входами разных целевых классов:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
x = Variable(torch.Tensor([[1.0,2.0,3.0], [1.0,2.0,3.0]]))
y = Variable(torch.LongTensor([1, 2]))
w = torch.Tensor([1.0,1.0,1.0])
res = F.cross_entropy(x,y,w)
# 0.9076
w = torch.Tensor([1.0,10.0,1.0])
res = F.cross_entropy(x,y,w)
# 1.3167