восстановить скрытое состояние сети LSTM в pytorch - PullRequest
1 голос
/ 27 февраля 2020

Я использую LSTM в моей модели, и моя функция пересылки определена следующим образом:


self.rnn = nn.LSTM(...)

def forward(self, input, hidden):
   #  dropout, linear embedding
    emb = self.drop(self.encoder(input.contiguous().view(-1, self.num_in_features)))  
    emb = emb.view(-1, input.size(1), self.rnn_hid_size)

    output, hidden = nn.LSTM(emb, hidden)
    output = self.drop(output)
    # [(seq_len x batch_size) * feature_size]
    decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
    # [ seq_len, batch_size, feature_size]
    decoded = decoded.view(output.size(0), output.size(1), decoded.size(1))

    return decoded, hidden, output

Я делаю некоторые выводы следующим образом:

model.eval()
hidden = model.init_hidden(1)


with torch.no_grad():
    for i in range(end_point):
        if i >= start_point:
            out, hidden, _ = model.forward(out, hidden)
       else:
            out, hidden, _ = model.forward(my_input_seq[i].unsqueeze(0), hidden)

Теперь, после запуска оператор

out, hidden, _ = model.forward(out, hidden)

После получения вывода я хочу отменить этот оператор, т.е. восстановить состояние LSTM до вызова. Я предполагаю, что это будет означать как-то отменить или восстановить скрытое состояние до вызова. Какой быстрый (и, надеюсь, простой) способ добиться этого в pytorch?

...