Да. Это будет работать. PyTorch создает вычислительный граф с узлами для всех операций, которые вы выполняете во время прямого прохода. Таким образом, для каждой операции в for l oop, независимо от того, как вы сохраняете выходные данные (или даже если вы их отбрасываете), вы все равно должны иметь возможность вызвать обратный вызов при потере, поскольку информация, необходимая для вычисления градиентов, уже существует. на графике.
Вот игрушечный пример:
import torch
from torch import nn
import torch.nn.functional as F
class RNN(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.wh = nn.Linear(hidden_dim, hidden_dim)
self.wx = nn.Linear(in_dim, hidden_dim)
self.wo = nn.Linear(hidden_dim, out_dim)
self.hidden_dim = hidden_dim
def forward(self, xs, hidden):
outputs = []
hiddens = []
hiddens.append(hidden)
for i, x in enumerate(xs):
hiddens.append(torch.tanh(self.wx(x) + self.wh(hiddens[i])))
outputs.append(F.log_softmax(self.wo(hiddens[i+1]), dim=0))
return outputs, hiddens
def init_hidden(self):
return torch.zeros(self.hidden_dim)
# Initialize the input and output
x = torch.tensor([[1., 0., 0.], [0., 1., 0.], [1., 0., 0.]])
y = torch.tensor([1])
# Initialize the network, hidden state and loss function
rnn = RNN(3, 10, 2)
hidden = rnn.init_hidden()
nll = nn.NLLLoss()
# Forward pass
outputs, hidden_states = rnn(x, hidden)
# Compute loss
loss = nll(outputs[-1].unsqueeze(0), y)
# Call Backward on the loss
loss.backward()
# Inspect the computed gradients
print(rnn.wh.weight.grad) # Gives a tensor of shape (10, 10) in this case