Ошибка Pytorch, RuntimeError: ожидаемый скалярный тип Long, но обнаружен Double - PullRequest
1 голос
/ 16 июня 2020

Я столкнулся со следующей ошибкой при обучении классификатора BERT.

type(b_input_mask) = type(b_labels) = torch.Tensor      

type(b_labels[i]) = tensor(1., dtype=torch.float64)

type(b_input_masks[i]) = class'torch.Tensor'

Что может быть здесь возможной ошибкой типа данных, поскольку я не привел ни одну переменную к типу long или double?

Заранее спасибо! Error Stack Trace

Ответы [ 2 ]

1 голос
/ 16 июня 2020

В задаче классификации тип данных для меток ввода должен быть Long, но вы назначили им float64

type(b_labels[i]) = tensor(1., dtype=torch.float64)

=>

type(b_labels[i]) = tensor(1., dtype=torch.long)
0 голосов
/ 16 июня 2020

Вы можете использовать torch.Tensor.long для преобразования тензора в ожидаемый тип long.

# Here, you can pass parameter like this in your call
..., labels = b_labels.long())
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...