AttributeError: объект 'tuple' не имеет атрибута 'dim' при подаче ввода в сеть Pytorch LSTM - PullRequest
0 голосов
/ 28 октября 2018

Я пытаюсь запустить следующий код:

import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        out = self.lstm(x)
        out = self.hidden2tag(out)
        return out


state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]

device = torch.device("cuda")
net = LSTM(5, 3).to(device)

state_v = torch.FloatTensor(state).to(device)

q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())

И это возвращает эту ошибку:

Traceback (most recent call last):
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 26, in <module>
    q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 15, in forward
    out = self.hidden2tag(out)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 55, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/functional.py", line 1022, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'

Кто-нибудь знает, как это исправить?(чтобы избавиться от тензора, являющегося кортежем, чтобы его можно было подавать в сеть LSTM)

Ответы [ 2 ]

0 голосов
/ 28 октября 2018

Pytorch LSTM возвращает кортеж.
Таким образом, вы получаете эту ошибку, поскольку ваш линейный слой self.hidden2tag не может обработать этот кортеж.

Поэтому измените:

out = self.lstm(x)

на

out, states = self.lstm(x)

Это исправит вашу ошибку, разделив кортеж так, чтобы out был просто вашим выходным тензором.

out затем сохраняет скрытые состояния, тогда как statesэто еще один кортеж, который содержит последнее скрытое состояние ячейки.

Вы также можете посмотреть здесь:
https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM

Вы получите еще одну ошибку для последней строки, так как max() также возвращает кортеж,Но это должно быть легко исправить и все же ошибка другая:)

0 голосов
/ 28 октября 2018

Сначала преобразуйте ваше состояние в массив NumPy:

state = np.array(state)

Вероятно, PyTorch отсутствует np.asarray в их API.

...