Почему этот параметр функции идентичен при каждом вызове, несмотря на передачу разных значений?(Создание замыканий в цикле) - PullRequest
0 голосов
/ 17 ноября 2018

Я использую PyTorch и пытаюсь зарегистрировать хуки для параметров модели.Следующий код создает лямбда-функции для добавления к каждому параметру модели, поэтому я могу видеть в ловушке, какой тензор градиент принадлежит

import torch
import torchvision

# define model and random train batch
model = torchvision.models.alexnet()
input = torch.rand(10, 3, 224, 224)   # batch of 10 images
targets = torch.zeros(10).long()

def grad_hook_template(param, name, grad):
    print(f'Receive grad for {name} w whape {grad.shape}')

# add one lambda hook to each parameter
for name, param in model.named_parameters():
    print(f'Register hook for {name}')

    # use a lambda so we can pass additional information to the hook, which should only take one parameter
    param.register_hook(lambda grad: grad_hook_template(param, name, grad))

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()

prediction = model(input)
loss = loss_fn(prediction, targets)
loss.backward()
optimizer.step()

В результате name и param аргументы grad_hook_template всегда как одно и то же значение (и id), но аргумент grad всегда отличается (как и ожидалось).Почему при регистрации хука лямбды, кажется, каждый раз ссылаются на одни и те же локальные переменные?

Я прочитал, например, здесь , что циклы не создают новых областей, а замыкания являются лексическимив Python, то есть name и param, которые я передаю лямбде, являются просто указателями, и любое значение, которое они имеют в конце цикла, видят все с этим указателем.Но что я могу поделать?copy.copy() переменные?

Ответы [ 2 ]

0 голосов
/ 17 ноября 2018

Вы столкнулись с запаздывающими закрытиями . Переменные param и name ищутся во время вызова, а не когда определяется функция, в которой они используются. К моменту вызова любой из этих функций name и param находятся на последних значениях в цикле. Чтобы обойти это, вы можете сделать это:

for name, param in model.named_parameters():
    print(f'Register hook for {name}')
    param.register_hook(lambda grad, name=name, param=param: grad_hook_template(param, name, grad))

Тем не менее, я думаю, что использование functools.partial является правильным решением здесь:

from functools import partial

for name, param in model.named_parameters():
    print(f'Register hook for {name}')
    param.register_hook(partial(grad_hook_template, name=name, param=param))

Более подробную информацию о поздних привязках можно найти на странице Common Gotchas Руководства автостопщика по Python , а также в документации по Python .

Обратите внимание, что это в равной степени относится к функциям, определенным ключевым словом def.

0 голосов
/ 17 ноября 2018

Это своего рода ответ FAQ .

Решения включают

  • с использованием functools.partial вместо lambda
  • использование параметров по умолчанию для лямбда-выражений для захвата значения переменных
...