Pytorch не поддерживает один горячий вектор? - PullRequest
1 голос
/ 06 апреля 2019

Я очень озадачен тем, как Pytorch работает с горячими векторами.В этом уроке нейронная сеть будет генерировать вектор с одним горячим током в качестве вывода.Насколько я понимаю, схематическая структура нейронной сети в учебнике должна быть такой:

enter image description here

Однако labels не находятся вГорячий векторный формат.Я получаю следующее size

print(labels.size())
print(outputs.size())

output>>> torch.Size([4]) 
output>>> torch.Size([4, 10])

Чудесным образом, я передаю outputs и labels criterion=CrossEntropyLoss(), ошибки вообще нет.

loss = criterion(outputs, labels) # How come it has no error?

МойГипотеза:

Возможно, Pytorch автоматически преобразует labels в векторную форму.Итак, я пытаюсь преобразовать метки в один горячий вектор перед передачей его в функцию потерь.

def to_one_hot_vector(num_class, label):
    b = np.zeros((label.shape[0], num_class))
    b[np.arange(label.shape[0]), label] = 1

    return b

labels_one_hot = to_one_hot_vector(10,labels)
labels_one_hot = torch.Tensor(labels_one_hot)
labels_one_hot = labels_one_hot.type(torch.LongTensor)

loss = criterion(outputs, labels_one_hot) # Now it gives me error

Однако я получил следующую ошибку

RuntimeError: multi-targetне поддерживается в /opt/pytorch/pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15

Таким образом, горячие векторы не поддерживаются в Pytorch?Как Pytorch вычисляет cross entropy для двух тензоров outputs = [1,0,0],[0,0,1] и labels = [0,2]?Это сейчас не имеет никакого смысла для меня.

1 Ответ

4 голосов
/ 06 апреля 2019

Я в замешательстве из-за вашей путаницы. PyTorch четко заявляет в своей документации для CrossEntropyLoss, что

Этот критерий предполагает индекс класса (от 0 до C-1) в качестве цели для каждого значения миниатюры с 1-мерным тензорным размером

Другими словами, он имеет вашу концепцию to_one_hot_vector, концептуально встроенную в CEL, и не предоставляет API-интерфейс one-hot. Обратите внимание, что горячие векторы неэффективны по сравнению с хранением меток классов.

Если вам даны горячие векторы и вам нужно перейти к формату меток классов (например, для совместимости с CEL), вы можете использовать argmax, как показано ниже:

import torch

labels = torch.tensor([1, 2, 3, 5])
one_hot = torch.zeros(4, 6)
one_hot[torch.arange(4), labels] = 1

reverted = torch.argmax(one_hot, dim=1)
assert (labels == reverted).all().item()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...