PyTorch Hooks - Как условно обнулить элементы градиента на основе пути? - PullRequest
0 голосов
/ 23 февраля 2020

Мне трудно понять, как реализовать то, что я хочу, в PyTorch: обратное распространение условного градиента пути.

Для простоты предположим, что у меня есть данные с формой (batch size, input_dimension), и у меня есть простая сеть который выводит скалярную сумму двух аффинных преобразований ввода, т.е.

linear1 = nn.Linear(in_features=input_dimension, out_features=1)
linear2 = nn.Linear(in_features=input_dimension, out_features=1)
y = linear1(x) + linear2(x)
loss = torch.mean((y - y_target) ** 2)

Во время backprop, я хотел бы обновить параметры linear1, используя только элементы в пакете, где $ y <0 $ и обновить параметры <code>linear2, используя только элементы в пакете, где $ y> 0 $.

Как я могу это реализовать?

Я пробовал register_backward_hook, но если я правильно понимаю функциональность, ко времени вызова зарегистрированной функции градиент ошибки по параметрам уже рассчитан. Я попытался register_hook, но это не позволяет мне условно маскировать градиент dL / dy в зависимости от того, какой линейный слой с backward() называется следующим.

Редактировать 1: Я думаю, что самое простое объяснение проблемы является следующим: Предположим, что я вычисляю градиент относительно тензора y, т.е. dL / dy. Как я могу создать две измененные версии dL / dy и направить их к различным подграфам вычислительного графа?

...