Pytorch: обратное дифференцирование по функции гессиана: « - PullRequest
0 голосов
/ 21 мая 2019

Я только что закончил отладку сумасшедшей ошибки в моем коде pytorch, поэтому я решил поделиться своим опытом.

В качестве минимального примера, вот код, который вылетает:

import torch
ts = torch.tensor
t = ts([i+.5 for i in range(5)], requires_grad=True) #some tensor
f = torch.sum(t**3) #some function of t
[lg] = torch.autograd.grad(f,t,create_graph=True,retain_graph=True) #the gradient of f wrt t
[lgg] = torch.autograd.grad(lg[1],t,create_graph=True,retain_graph=True) #one row of the Hessian of f wrt t
lgg2 = lgg.contiguous().view(-1)
lgg3 = torch.zeros(6,6,requires_grad=True) #if we were calculating the full Hessian, it would be a matrix...
lgg3[1,1:].add_(lgg2.type_as(lgg3)) #...and we'd fill it in row by row 
l2 = torch.sum(lgg3) #a loss that's a function of the Hessian
l2.backward() #we want the gradient of that loss
print(t.grad)
print(lgg3.requires_grad)

Это дает ошибку:

RuntimeError: leaf variable has been moved into the graph interior

Как это исправить? Я дам вам ответ ...

1 Ответ

0 голосов
/ 21 мая 2019

Не устанавливайте requires_grad при создании lgg3. В конце этого кода для lgg3 все равно будет установлено значение True, но это останавливает ошибку. Это вуду, но это работает.

import torch
ts = torch.tensor
t = ts([i+.5 for i in range(5)], requires_grad=True) #some tensor
f = torch.sum(t**3) #some function of t
[lg] = torch.autograd.grad(f,t,create_graph=True,retain_graph=True) #the gradient of f wrt t
[lgg] = torch.autograd.grad(lg[1],t,create_graph=True,retain_graph=True) #one row of the Hessian of f wrt t
lgg2 = lgg.contiguous().view(-1)
lgg3 = torch.zeros(6,6) #if we were calculating the full Hessian, it would be a matrix...
lgg3[1,1:].add_(lgg2.type_as(lgg3)) #...and we'd fill it in row by row 
l2 = torch.sum(lgg3) #a loss that's a function of the Hessian
l2.backward() #we want the gradient of that loss
print(t.grad)
print(lgg3.requires_grad) #look, it's True.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...