Я использую LSTM в моей модели, и моя функция пересылки определена следующим образом:
self.rnn = nn.LSTM(...)
def forward(self, input, hidden):
# dropout, linear embedding
emb = self.drop(self.encoder(input.contiguous().view(-1, self.num_in_features)))
emb = emb.view(-1, input.size(1), self.rnn_hid_size)
output, hidden = nn.LSTM(emb, hidden)
output = self.drop(output)
# [(seq_len x batch_size) * feature_size]
decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
# [ seq_len, batch_size, feature_size]
decoded = decoded.view(output.size(0), output.size(1), decoded.size(1))
return decoded, hidden, output
Я делаю некоторые выводы следующим образом:
model.eval()
hidden = model.init_hidden(1)
with torch.no_grad():
for i in range(end_point):
if i >= start_point:
out, hidden, _ = model.forward(out, hidden)
else:
out, hidden, _ = model.forward(my_input_seq[i].unsqueeze(0), hidden)
Теперь, после запуска оператор
out, hidden, _ = model.forward(out, hidden)
После получения вывода я хочу отменить этот оператор, т.е. восстановить состояние LSTM до вызова. Я предполагаю, что это будет означать как-то отменить или восстановить скрытое состояние до вызова. Какой быстрый (и, надеюсь, простой) способ добиться этого в pytorch?