Причина, по которой это работает без retain_graph=True
, в вашем случае состоит в том, что у вас очень простой график, который, вероятно, не будет иметь внутренних промежуточных буферов, в свою очередь, никакие буферы не будут освобождены, поэтому нет необходимости использовать retain_graph=True
.
Но все меняется при добавлении еще одного дополнительного вычисления к вашему графику:
Код:
x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2
y.backward(torch.ones(2, 2))
print('Backward 1st time w/o retain')
print('x.grad:', x.grad)
print('Backward 2nd time w/o retain')
try:
y.backward(torch.ones(2, 2))
except RuntimeError as err:
print(err)
print('x.grad:', x.grad)
Вывод:
Backward 1st time w/o retain
x.grad: tensor([[3., 3.],
[3., 3.]])
Backward 2nd time w/o retain
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
x.grad: tensor([[3., 3.],
[3., 3.]]).
В этом случаедополнительный внутренний v.grad
будет вычислен, но torch
не хранит промежуточные значения (промежуточные градиенты и т. д.), а с retain_graph=False
v.grad
будет освобожден после первого backward
.
Итак,если вы хотите выполнить обратное преобразование во второй раз, вам нужно указать retain_graph=True
, чтобы «сохранить» график.
Код:
x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2
y.backward(torch.ones(2, 2), retain_graph=True)
print('Backward 1st time w/ retain')
print('x.grad:', x.grad)
print('Backward 2nd time w/ retain')
try:
y.backward(torch.ones(2, 2))
except RuntimeError as err:
print(err)
print('x.grad:', x.grad)
Вывод:
Backward 1st time w/ retain
x.grad: tensor([[3., 3.],
[3., 3.]])
Backward 2nd time w/ retain
x.grad: tensor([[6., 6.],
[6., 6.]])