pyTorch может вернуться назад дважды без установки retain_graph = True - PullRequest
0 голосов
/ 23 сентября 2018

Как указано в PyTorch Tutorial ,

, если вы даже хотите сделать обратную часть в некоторой части графика дважды, вам нужно передать retain_graph = True во времяпервый проход.

Однако я обнаружил, что следующий фрагмент кода фактически работает без этого.Я использую pyTorch-0.4

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
y.backward(torch.ones(2, 2)) # Note I do not set retain_graph=True
y.backward(torch.ones(2, 2)) # But it can still work!
print x.grad

вывод:

tensor([[ 2.,  2.], 
        [ 2.,  2.]]) 

Может кто-нибудь объяснить?Заранее спасибо!

1 Ответ

0 голосов
/ 07 апреля 2019

Причина, по которой это работает без retain_graph=True, в вашем случае состоит в том, что у вас очень простой график, который, вероятно, не будет иметь внутренних промежуточных буферов, в свою очередь, никакие буферы не будут освобождены, поэтому нет необходимости использовать retain_graph=True.

Но все меняется при добавлении еще одного дополнительного вычисления к вашему графику:

Код:

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2

y.backward(torch.ones(2, 2))

print('Backward 1st time w/o retain')
print('x.grad:', x.grad)

print('Backward 2nd time w/o retain')

try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)

print('x.grad:', x.grad)

Вывод:

Backward 1st time w/o retain
x.grad: tensor([[3., 3.],
                [3., 3.]])
Backward 2nd time w/o retain
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
x.grad: tensor([[3., 3.],
                [3., 3.]]).

В этом случаедополнительный внутренний v.grad будет вычислен, но torch не хранит промежуточные значения (промежуточные градиенты и т. д.), а с retain_graph=False v.grad будет освобожден после первого backward.

Итак,если вы хотите выполнить обратное преобразование во второй раз, вам нужно указать retain_graph=True, чтобы «сохранить» график.

Код:

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2

y.backward(torch.ones(2, 2), retain_graph=True)

print('Backward 1st time w/ retain')
print('x.grad:', x.grad)

print('Backward 2nd time w/ retain')

try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)
print('x.grad:', x.grad)

Вывод:

Backward 1st time w/ retain
x.grad: tensor([[3., 3.],
                [3., 3.]])
Backward 2nd time w/ retain
x.grad: tensor([[6., 6.],
                [6., 6.]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...