Как избежать второго прохода в состязательной модели PyTorch? - PullRequest
1 голос
/ 05 мая 2020

Класс c Пример / учебник PyTorch для обучения GAN l oop показан здесь . Вы можете найти строки

# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)

, но, поскольку такой противник может стать довольно большим, это кажется совершенно ненужным делом. Чтобы быть более конкретным, у меня есть минимальный пример моего (в некотором смысле другого) состязательного обучения, где это имеет еще большее значение (см. Конец этого поста). По сути, у меня есть сеть X, которая выдает результат out1, и сеть Y, которая производит два члена потерь out2_x и out2_y. Каждый из этих терминов потерь следует использовать для обновления весов соответствующей модели. Итак, чтобы обновить веса X, мне нужно выполнить полный обратный проход через обе сети. Если я подсчитываю затраты времени выполнения с точки зрения оценок сети (вперед или назад), у меня теперь было 2 прямых и 2 обратных оценки сети. Чтобы обновить вторую сеть, я надеялся просто очистить предыдущие градиенты этой части и выполнить обратный проход до out1 и остановиться на этом. Это добавило бы одну дополнительную оценку обратной сети, и у меня есть все необходимые обновления градиента.

В вышеупомянутом руководстве GAN предлагается выполнить дополнительный прямой проход во второй сети, что даст в общей сложности 6 вместо 5 сетей. оценки, которые мне кажутся легко устранимыми и ненужными.

Однако моя идея как-то остановить обратный проход автографов раньше не кажется легко осуществимой. Я нашел решение, которое ведет себя так, как описано выше (см. Пример ниже), но мне интересно, не слишком ли оно дорого, поскольку установка require_grads на False для всех X весов означает, что PyTorch также должен выполнить некоторые дополнительные обновления графиков. чтобы выяснить, насколько глубоким обратным проходом должен быть go, и поскольку это происходит на стороне Python (верно?), это может быть потенциально довольно медленным для больших сетевых архитектур с множеством однотензорных операций ?! Кроме того, имея один единственный тензор узких мест out1, я надеялся получить более чистый код, указав autograd специально для остановки обратного прохода там, и теперь мне действительно любопытно, почему это кажется таким сложным для достижения в структуре.

Любые идеи очень приветствуются!

import torch

# Setting up a dummy graph
x1: torch.Tensor = torch.tensor(1.0, requires_grad=True)
x2: torch.Tensor = torch.tensor(3.0, requires_grad=True)

out1 = x1 * (x2 + 1)
y: torch.Tensor = torch.tensor(11.0, requires_grad=True)

out2_x = out1 * y + y ** 2
out2_y = out1 ** 2 * y ** 2

opt1 = torch.optim.SGD([x1, x2], lr=0.01)
opt2 = torch.optim.SGD([y], lr=0.01)
# Finished setting up a dummy graph

# showcasing some gradients
opt1.zero_grad()
opt2.zero_grad()
out2_x.backward(retain_graph=True)
print("grads wrt out2_x\t\t\t:", x1.grad, x2.grad, y.grad)
opt1.zero_grad()
opt2.zero_grad()
out2_y.backward(retain_graph=True)
print("grads wrt out2_y\t\t\t:", x1.grad, x2.grad, y.grad)

# actual logic for computing gradients
opt1.zero_grad()
opt2.zero_grad()
out2_x.backward(retain_graph=True)
# the following two lines work
x1.requires_grad = False
x2.requires_grad = False
# but I was hoping something like the following would work
#   * out1.requires_grad = False
#   * out1.is_leaf = True
#   * out1.detach() (this would require a second forward pass)
#   * out2_y.backward(tell it to stop at out1)
print("grads wrt out2_x\t\t\t:", x1.grad, x2.grad, y.grad)
opt2.zero_grad()
print("grads after zeroing y\t\t\t:", x1.grad, x2.grad, y.grad)
out2_y.backward()
print("grads after out2_y with no_grads x\t:", x1.grad, x2.grad, y.grad)
opt1.step()
opt2.step()
# x1.requires_grad = True
# x2.requires_grad = True

Вывод:

grads wrt out2_x            : tensor(44.) tensor(11.) tensor(26.)
grads wrt out2_y            : tensor(3872.) tensor(968.) tensor(352.)
grads wrt out2_x            : tensor(44.) tensor(11.) tensor(26.)
grads after zeroing y           : tensor(44.) tensor(11.) tensor(0.)
grads after out2_y with no_grads x  : tensor(44.) tensor(11.) tensor(352.)
...