Как правильно рассчитать потери мульти-меток в F.nll_loss у pytorch при работе с бесполезными метками? - PullRequest
0 голосов
/ 12 июня 2019

Представьте себе, что мой номер метки равен 100, но в наборе данных некоторые данные повреждены, поэтому я устанавливаю его на 0 в его представлении в одно касание, когда я использую 'motor ll_loss', я принимаю местоположение повреждения в качестве значения - 1, таким образом, количество меток равно 101, и, таким образом, появилась ошибка, есть ли способ, как тензорный поток правильно обрабатывает вышеуказанные данные о повреждениях, а не очищает данные о повреждениях?

logits = F.log_softmax(torch.randn(5, 100), dim=1)
idx_train = torch.as_tensor([1, 2, 3]).long()
idx_train_labels = torch.as_tensor([0, 4, 2]).long()
fail_idx_train_labels = torch.as_tensor([2, 4, 101]).long()

# right
F.nll_loss(logits[idx_train], idx_train_labels) 
# RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. 
F.nll_loss(logits[idx_train], fail_idx_train_labels) 
...