У меня есть:
def forward(self, x, hidden=None):
lstm_out, hidden = self.lstm(x, hidden)
print('lstm_out.size', lstm_out.size())
lstm_out = lstm_out.view(-1, lstm_out.shape[2])
out = self.linear(lstm_out)
print('out', out.size())
И это не работает. Мой self.linear
- self.linear = nn.Linear(64 * seq_length, 5)
. Я могу изменить 5 на что угодно позже.
Итак, size
моего lstm_out
равно torch.Size([64, 20, 322])
. Но тогда я получаю ошибку при выполнении self.linear
: RuntimeError: size mismatch, m1: [1280 x 322], m2: [1280 x 5] at
. Что я делаю не так?