Проблема с nn.embroduction в pytorch, ожидаемый скалярный тип Long, но получил torch.cuda.FloatTensor (как исправить)? - PullRequest
0 голосов
/ 14 октября 2019

поэтому у меня есть кодер RNN, который является частью более крупной языковой модели, где процесс кодируется -> rnn -> декодировать.

Как часть моего init для моего класса rnn, у меня есть следующее:

self.encode_this = nn.Embedding(self.vocab_size, self.embedded_vocab_dim)

Теперь я пытаюсь реализовать класс forward, который принимает пакеты ивыполняет кодирование, а затем декодирование,

def f_calc(self, batch):

    #Here, batch.shape[0] is the size of batch while batch.shape[1] is the sequence length

    hidden_states = (torch.zeros(self.num_layers, batch.shape[0], self.hidden_vocab_dim).to(device))
    embedded_states = (torch.zeros(batch.shape[0],batch.shape[1], self.embedded_vocab_dim).to(device))

    o1, h = self.encode_this(embedded_states)

однако моя проблема всегда связана с кодировщиком, который выдает мне следующую ошибку:

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1465         # remove once script supports set_grad_enabled
   1466         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1467     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1468 
   1469 

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

Кто-нибудь знает, как это исправить? Я совершенно новичок в Pytorch, поэтому, пожалуйста, извините, если это глупый вопрос. Я знаю, что есть какая-то форма приведения типов, но я не знаю, как это сделать ...

очень ценится!

1 Ответ

0 голосов
/ 14 октября 2019

Встраиваемый слой ожидает целые числа на входе.

import torch as t

emb = t.nn.Embedding(embedding_dim=3, num_embeddings=26)

emb(t.LongTensor([0,1,2]))

enter image description here

Добавьте long() в ваш код:

embedded_states = (torch.zeros(batch.shape[0],batch.shape[1], self.embedded_vocab_dim).to(device)).long()
...