В pytorch, как использовать весовой параметр в F.cross_entropy ()? - PullRequest
0 голосов
/ 09 мая 2018

Я пытаюсь написать код, подобный приведенному ниже:

x = Variable(torch.Tensor([[1.0,2.0,3.0]]))
y = Variable(torch.LongTensor([1]))
w = torch.Tensor([1.0,1.0,1.0])
F.cross_entropy(x,y,w)
w = torch.Tensor([1.0,10.0,1.0])
F.cross_entropy(x,y,w)

Тем не менее, выход перекрестной энтропийной потери всегда равен 1.4076 независимо от того, w . Что стоит за весовым параметром для F.cross_entropy ()? Как правильно его использовать?
Я использую pytorch 0.3

1 Ответ

0 голосов
/ 09 мая 2018

Параметр weight используется для вычисления взвешенного результата для всех входных данных на основе их целевого класса. Если у вас есть только один вход или все входы одного и того же целевого класса, weight не повлияет на потери.

См. Разницу между двумя входами разных целевых классов:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

x = Variable(torch.Tensor([[1.0,2.0,3.0], [1.0,2.0,3.0]]))
y = Variable(torch.LongTensor([1, 2]))
w = torch.Tensor([1.0,1.0,1.0])
res = F.cross_entropy(x,y,w)
# 0.9076
w = torch.Tensor([1.0,10.0,1.0])
res = F.cross_entropy(x,y,w)
# 1.3167
...