Я пытаюсь выучить RNN и Pytorch.
Итак, я увидел несколько кодов для RNN, где в методе прямого пробагации они сделали такую проверку:
def forward(self, inputs, hidden):
if inputs.is_cuda:
device = inputs.get_device()
else:
device = torch.device("cpu")
embed_out = self.embeddings(inputs)
logits = torch.zeros(self.seq_len, self.batch_size, self.vocab_size).to(device)
Я думаю, что смысл проверки в том, чтобы посмотреть, сможем ли мы запустить код на более быстром GPU вместо CPU? Чтобы понять код немного больше, я сделал следующее:
ex= torch.zeros(3,10,5)
ex1= torch.tensor(np.array([[0,0,0,1,0], [1,0,0,0,0],[0,1,0,0,0]]))
print(ex)
print("device is")
print(ex1.get_device())
print(ex.to(ex1.get_device()))
И вывод был:
...
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]])
device is
-1
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-2-b09342e2ba0f> in <module>()
67 print("device is")
68 print(ex1.get_device())
---> 69 print(ex.to(ex1.get_device()))
RuntimeError: Device index must not be negative
Я не понимаю "устройство" в коде, и я не понимаю .to(device)
метод. Можете ли вы помочь мне понять это?