Я думаю, вы можете использовать эти хуки для хранения градиентов в глобальной переменной:
grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y
x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))
z.backward()
Но вам, скорее всего, также нужно помнить соответствующий тензор, для которого были рассчитаны эти градиенты. В этом случае мы немного расширяемся, используя dict
вместо list
:
grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y
def store(grad,parent):
print(grad,parent)
grads[parent] = grad.clone()
x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))
z.sum().backward()
Теперь вы можете, например, получить доступ к градиенту тензора y
, просто используя grads[y]