Я кодирую простую сеть LSTM, но при вызове summary () в моей модели возникает следующая ошибка:
RuntimeError: Ожидается скрытый [0] size (1, 2, 204), got (1, 16, 204) .
Здесь ниже я помещаю свой класс LSTM вместе с кодом, который дает мне ошибку.
class LSTM(nn.Module):
def __init__(self,input_size,hidden_size,batch_size):
super(LSTM, self).__init__()
self.lstm = nn.LSTM(input_size,hidden_size,batch_first=True)
self.h_n = torch.zeros(1,batch_size,hidden_size)
self.c_n = torch.zeros(1,batch_size,hidden_size)
def forward(self, u):
output,(self.h_n,self.c_n) = self.lstm(u,(self.h_n,self.c_n))
return self.h_n
Следующий код дает мне ошибку времени выполнения:
from torchsummary import summary
HISTORY = 5
nx = 200
ng = 2
batchsize = 16
model_Surrogate = LSTM(nx+2*ng,nx+2*ng,batchsize).to(device)
optimizer_Surrogate = optim.Adam(model_Surrogate.parameters(), lr=1e-2)
summary(model_Surrogate, (HISTORY,nx+2*ng))
u имеет форму (пакетный размер, 5,204) , а именно (16,5204) .
На самом деле, я не знаю, откуда взялась ошибка 2 , вместо этого я должен установить размер пакета, который я установил в 16 . Я знаю, что summray () не знает u раньше времени, поэтому, возможно, 2 это просто заполнитель для реального размера пакета?
Кроме того, когда я пропускаю этап подведения итогов и непосредственно обучаю сеть, мой сеанс Colab просто терпит крах ...
Я не могу понять это, поэтому любая помощь будет оценена!
Так пока я пытался установить batch_first на False и адаптировать ввод, но ничего не изменилось.
Я также удалил параметры self.h_n и self.c_n и я получил это сообщение:
---------------------------------------------------------------------------
AttributeError
3 optimizer_Surrogate = optim.Adam(model_Surrogate.parameters(), lr=1e-2)
----> 4 summary(model_Surrogate, (HISTORY,nx+2*ng))
5 frames
/usr/local/lib/python3.6/dist-packages/torchsummary/torchsummary.py in <listcomp>(.0)
21 if isinstance(output, (list, tuple)):
22 summary[m_key]["output_shape"] = [
---> 23 [-1] + list(o.size())[1:] for o in output
24 ]
25 else:
AttributeError: 'tuple' object has no attribute 'size'