В PyTorch, как сделать определенный модуль `Parameters` статическим во время обучения? - PullRequest
0 голосов
/ 02 мая 2019

Контекст:

В pytorch любой Parameter является особым видом Tensor. Parameter автоматически регистрируется методом parameters() модуля, когда ему присваивается атрибут.

Во время обучения я передам m.parameters() экземпляру Optimizer, чтобы они могли быть обновлены.


Вопрос : Как предотвратить оптимизацию определенных параметров оптимизатором для встроенного модуля pytorch?

s = Sequential(
        nn.Linear(2,2),
        nn.Linear(2,3),   # I want this one's .weight and .bias to be constant
        nn.Linear(3,1)
    )
  • Можно ли сделать так, чтобы они не появлялись в s.parameters()?
  • Могу ли я сделать параметры доступными только для чтения, чтобы любые попытки изменения игнорировались?

1 Ответ

0 голосов
/ 03 мая 2019

Параметры можно сделать статическими, установив их атрибут requires_grad=False.

В моем примере:

params = list(s.parameters())  # .parameters() returns a generator
# Each linear layer has 2 parameters (.weight and .bias),
# Skipping first layer's parameters (indices 0, 1):
params[2].requires_grad = False
params[3].requires_grad = False

Когда для вычисления используется сочетание тензоров requires_grad=True и requires_grad=False, результат наследует requires_grad=True.

В соответствии с Документация по механике автограда PyTorch :

Если в операции, требующей градиента, есть один вход, для его вывода также потребуется градиент. И наоборот, только если все входные данные не требуют градиента, выходные данные также не требуют его. Вычисления в обратном направлении никогда не выполняются в подграфах, где всем тензорам не требуются градиенты.


Меня беспокоило, что если я отключу отслеживание градиента для среднего слоя, первый слой не получит градиенты с обратным распространением. Это было ошибочное понимание.

Edge Case : Если я отключу градиенты для всех параметров в модуле и попытаюсь обучить, оптимизатор вызовет исключение. Потому что нет единого тензора, к которому можно применить пропуск backward().

В этом крайнем случае я получал ошибки Я пытался протестировать параметры requires_grad=False для модуля с одним слоем nn.Linear. Это означало, что я отключил отслеживание для всех параметров, из-за чего оптимизатор жаловался.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...