Ниже приведена ошибка, которую я получаю при обучении RNN:
RuntimeError: Попытка перевернуть график во второй раз, но буферы уже освобождены. Укажите retain_graph = True при обратном вызове в первый раз.
Я попробовал решение, данное здесь
Я не понимаю, почему это происходит. Выполнение .backward(retain_graph=True)
устраняет ошибку, но тогда потеря не уменьшается независимо от того, сколько эпох.
Возможно, потому, что я не совсем понимаю «часть отрыва» в вычислительном графе.
RNN Config:
batch_size = 1
input_size = 1
sequence_length = 10
hidden_size = 1
num_layer = 10
RNN Класс:
class ModelRnn(nn.Module):
def __init__(self):
super(ModelRnn, self).__init__()
self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True, num_layers=num_layer)
def forward(self, x, hidden):
x = x.view(batch_size, sequence_length, input_size)
out, hidden = self.rnn(x, hidden)
return hidden, out
def init_hidden(self):
hidden_state = Variable(torch.zeros(num_layer, batch_size, hidden_size))
return (hidden_state)
Цикл поезда:
hidden = model.init_hidden()
for epoch in range(5):
for x_batch, y_batch in train_loader:
model.zero_grad()
hidden, output = model(x_batch, hidden)
optimizer.zero_grad()
loss = criterion(output, y_batch)
print(f"{epoch+1} epoch | loss = {loss}")
loss.backward()
optimizer.step()
Я получаю следующую ошибку:
1 epoch | loss = 96040000.0
2 epoch | loss = 96040000.0
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-90-8a1b1c3e9bec> in <module>()
9
10 print(f"{epoch+1} epoch | loss = {loss}")
---> 11 loss.backward()
12 optimizer.step()
~/anaconda3/envs/arena-py3.6/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
100 products. Defaults to ``False``.
101 """
--> 102 torch.autograd.backward(self, gradient, retain_graph, create_graph)
103
104 def register_hook(self, hook):
~/anaconda3/envs/arena-py3.6/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
88 Variable._execution_engine.run_backward(
89 tensors, grad_tensors, retain_graph, create_graph,
---> 90 allow_unreachable=True) # allow_unreachable flag
91
92
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Модель должна тренироваться без заминки, но мне здесь чего-то не хватает, чего я не могу понять.