Я использую PyTorch и пытаюсь зарегистрировать хуки для параметров модели.Следующий код создает лямбда-функции для добавления к каждому параметру модели, поэтому я могу видеть в ловушке, какой тензор градиент принадлежит
import torch
import torchvision
# define model and random train batch
model = torchvision.models.alexnet()
input = torch.rand(10, 3, 224, 224) # batch of 10 images
targets = torch.zeros(10).long()
def grad_hook_template(param, name, grad):
print(f'Receive grad for {name} w whape {grad.shape}')
# add one lambda hook to each parameter
for name, param in model.named_parameters():
print(f'Register hook for {name}')
# use a lambda so we can pass additional information to the hook, which should only take one parameter
param.register_hook(lambda grad: grad_hook_template(param, name, grad))
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
prediction = model(input)
loss = loss_fn(prediction, targets)
loss.backward()
optimizer.step()
В результате name
и param
аргументы grad_hook_template
всегда как одно и то же значение (и id
), но аргумент grad
всегда отличается (как и ожидалось).Почему при регистрации хука лямбды, кажется, каждый раз ссылаются на одни и те же локальные переменные?
Я прочитал, например, здесь , что циклы не создают новых областей, а замыкания являются лексическимив Python, то есть name
и param
, которые я передаю лямбде, являются просто указателями, и любое значение, которое они имеют в конце цикла, видят все с этим указателем.Но что я могу поделать?copy.copy()
переменные?