PyTorch: _thnn_nll_loss_forward не реализован для типа torch.LongTensor - PullRequest
1 голос
/ 30 апреля 2019

При попытке создать модель с использованием PyTorch, когда я пытаюсь реализовать функцию потерь nll_loss, она выдает следующую ошибку

RuntimeError: _thnn_nll_loss_forward is not implemented for type torch.LongTensor 

Созданная мной функция подгонки:

for epoch in tqdm_notebook(range(1, epochs+1)):
    for batch_idx, (data, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        net.float()
        output = net(data)
        output_x = output.argmax(dim=2) #to convert (64,50,43) -> (64, 50)
        loss = F.nll_loss(output_x, targets)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train epochs: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx*len(data), len(ds.data),
                100.*batch_idx / len(ds), loss.item()
            ))

Там, где форма выходных данных и целей равна (64, 50) , а dtypes для обоих * torch.int64.

Ответы [ 2 ]

3 голосов
/ 30 апреля 2019

Посмотрите на описание из F.nll_loss. Ожидается получить в качестве входных данных не argmax прогноза (тип torch.long), а скорее полные векторы предсказания 64x50x43 (типа torch.float). Обратите внимание, что на самом деле предсказание, которое вы предоставляете F.nll_loss, имеет дополнительное измерение больше, чем целевые истинные цели, которые вы предоставляете.

В вашем случае просто удалите argmax:

loss = F.nll_loss(output, targets)
2 голосов
/ 30 апреля 2019

Похоже, что вы обрабатываете задачу классификации с 43 классами, используя размер пакета 64 с "длиной последовательности", равной 50.

Если это так, я считаю, что вы немногопутать с использованием argmax() или F.log_softmax.Поскольку Шай дал ссылку, учитывая, что output является логит-значением, вы можете использовать:

output_x = F.log_softmax(output, dim=2)
loss = F.nll_loss(output_x, targets)

Это правильный способ использования nll_loss, или если вы не хотите делать log_softmaxВы можете использовать nn.CrossEntropyLoss вместо.

...