Как вернуть промежуточные градиенты (для неконечных узлов) в pytorch? - PullRequest
0 голосов
/ 22 марта 2019

Мой вопрос касается синтаксиса pytorch register_hook.

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y

x.register_hook(print)
y.register_hook(print)

z.backward()

выходы:

tensor([2.])
tensor([4.])

этот фрагмент просто печатает градиент z w.r.t x и y соответственно.

Теперь мой (скорее всего тривиальный) вопрос - как вернуть промежуточные градиенты (а не только печать)?

UPDATE:

Похоже, что вызов retain_grad() решает проблему для конечных узлов. ех. y.retain_grad().

Однако, retain_grad, похоже, не решает ее для неконечных узлов. Есть предложения?

1 Ответ

0 голосов
/ 23 марта 2019

Я думаю, вы можете использовать эти хуки для хранения градиентов в глобальной переменной:

grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y

x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))

z.backward()

Но вам, скорее всего, также нужно помнить соответствующий тензор, для которого были рассчитаны эти градиенты. В этом случае мы немного расширяемся, используя dict вместо list:

grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y

def store(grad,parent):
    print(grad,parent)
    grads[parent] = grad.clone()

x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))

z.sum().backward()

Теперь вы можете, например, получить доступ к градиенту тензора y, просто используя grads[y]

...