RuntimeError: Ожидаемый объект скалярного типа Double, но получил скалярный тип Float для аргумента № 2 - PullRequest
0 голосов
/ 02 марта 2020

У меня есть модель PyTorch LSTM, и моя функция forward выглядит следующим образом:

    def forward(self, x, hidden):
        print('in forward', x.dtype, hidden[0].dtype, hidden[1].dtype)
        lstm_out, hidden = self.lstm(x, hidden)
        return lstm_out, hidden

Все операторы print показывают torch.float64, что я считаю двойным. Итак, почему я получаю эту проблему?

Я уже набрал double во всех соответствующих местах.

1 Ответ

1 голос
/ 02 марта 2020

Убедитесь, что ваши данные и модель указаны в dtype double.

Для модели:

net = net.double()

Для данных:

net(x.double)

Это было обсуждено на форуме PyTorch .

...