Слой torch.nn.LSTM затрудняет работу summary () и мой сеанс Colab падает - PullRequest
0 голосов
/ 26 апреля 2020

Я кодирую простую сеть LSTM, но при вызове summary () в моей модели возникает следующая ошибка:

RuntimeError: Ожидается скрытый [0] size (1, 2, 204), got (1, 16, 204) .

Здесь ниже я помещаю свой класс LSTM вместе с кодом, который дает мне ошибку.

class LSTM(nn.Module):

def __init__(self,input_size,hidden_size,batch_size):
    super(LSTM, self).__init__()

    self.lstm = nn.LSTM(input_size,hidden_size,batch_first=True)
    self.h_n = torch.zeros(1,batch_size,hidden_size)
    self.c_n = torch.zeros(1,batch_size,hidden_size)


def forward(self, u):
    output,(self.h_n,self.c_n) = self.lstm(u,(self.h_n,self.c_n))

    return self.h_n

Следующий код дает мне ошибку времени выполнения:

from torchsummary import summary
HISTORY = 5
nx = 200
ng = 2
batchsize = 16
model_Surrogate = LSTM(nx+2*ng,nx+2*ng,batchsize).to(device)
optimizer_Surrogate = optim.Adam(model_Surrogate.parameters(), lr=1e-2)
summary(model_Surrogate, (HISTORY,nx+2*ng))

u имеет форму (пакетный размер, 5,204) , а именно (16,5204) .

На самом деле, я не знаю, откуда взялась ошибка 2 , вместо этого я должен установить размер пакета, который я установил в 16 . Я знаю, что summray () не знает u раньше времени, поэтому, возможно, 2 это просто заполнитель для реального размера пакета?

Кроме того, когда я пропускаю этап подведения итогов и непосредственно обучаю сеть, мой сеанс Colab просто терпит крах ...

Я не могу понять это, поэтому любая помощь будет оценена!

Так пока я пытался установить batch_first на False и адаптировать ввод, но ничего не изменилось.

Я также удалил параметры self.h_n и self.c_n и я получил это сообщение:

---------------------------------------------------------------------------
AttributeError
      3 optimizer_Surrogate = optim.Adam(model_Surrogate.parameters(), lr=1e-2)
----> 4 summary(model_Surrogate, (HISTORY,nx+2*ng))

5 frames
/usr/local/lib/python3.6/dist-packages/torchsummary/torchsummary.py in <listcomp>(.0)
     21             if isinstance(output, (list, tuple)):
     22                 summary[m_key]["output_shape"] = [
---> 23                     [-1] + list(o.size())[1:] for o in output
     24                 ]
     25             else:

AttributeError: 'tuple' object has no attribute 'size'
...