Устройство Pytorch и метод .to (устройство) - PullRequest
1 голос
/ 17 марта 2020

Я пытаюсь выучить RNN и Pytorch.

Итак, я увидел несколько кодов для RNN, где в методе прямого пробагации они сделали такую ​​проверку:

def forward(self, inputs, hidden):
    if inputs.is_cuda:
        device = inputs.get_device()
    else:
        device = torch.device("cpu")
    embed_out = self.embeddings(inputs)
    logits = torch.zeros(self.seq_len, self.batch_size, self.vocab_size).to(device)

Я думаю, что смысл проверки в том, чтобы посмотреть, сможем ли мы запустить код на более быстром GPU вместо CPU? Чтобы понять код немного больше, я сделал следующее:

ex= torch.zeros(3,10,5)
ex1= torch.tensor(np.array([[0,0,0,1,0], [1,0,0,0,0],[0,1,0,0,0]]))

print(ex)
print("device is")
print(ex1.get_device())
print(ex.to(ex1.get_device()))

И вывод был:

        ...
        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
device is
-1
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-b09342e2ba0f> in <module>()
     67 print("device is")
     68 print(ex1.get_device())
---> 69 print(ex.to(ex1.get_device()))

RuntimeError: Device index must not be negative

Я не понимаю "устройство" в коде, и я не понимаю .to(device) метод. Можете ли вы помочь мне понять это?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...