Оптимизатор генератора также обучает дискриминатор? - PullRequest
1 голос
/ 21 марта 2020

При изучении GAN я заметил, что примеры кода демонстрируют этот паттерн:

Дискриминатор обучен так:

d_optim.zero_grad()

real_pred = d(real_batch)
d_loss = d_loss_fn(real_pred, torch.ones(real_batch_size, 1))
d_loss.backward()

fake_pred = d(g(noise_batch).detach())
d_loss = d_loss_fn(fake_pred, torch.zeros(noise_batch_size, 1))
d_loss.backward()

d_optim.step()

Генератор обучен так:

g_optim.zero_grad()

fake_pred = d(g(noise_batch))
g_loss = g_loss_fn(fake_pred, torch.ones(noise_batch_size, 1))
g_loss.backward()

g_optim.step()

Упоминается, что d(g(noise_batch).detach()) написано для дискриминатора вместо d(g(noise_batch)), чтобы помешать d_optim.step() обучить g, но ничего не сказано о d(g(noise_batch)) для генератора; g_optim.step() будет также тренироваться d?

На самом деле, почему мы d(g(noise_batch).detach()), если, например, d_optim = torch.optim.SGD(d.parameters(), lr=0.001)? Разве это не указывает, что d.parameters(), а также g.parameters() должны быть обновлены?

1 Ответ

2 голосов
/ 21 марта 2020

TLDR: optimizer обновит только параметры, указанные для него, тогда как вызов backward() вычисляет градиенты для всех переменных в графе вычислений. Таким образом, полезно detach() переменных, для которых вычисление градиента не требуется в данный момент.

Я полагаю, что ответ заключается в том, как все реализовано в PyTorch.

  • tensor.detach() создает тензор, который совместно использует хранилище с tensor, для которого не требуется grad. Таким образом, вы фактически обрезаете график вычислений. То есть выполнение fake_pred = d(g(noise_batch).detach()) отсоединит (обрезает) график вычислений генератора.
  • Когда вы вызываете backward() для потери, градиенты рассчитываются для всего графика вычислений (независимо от того, использует его оптимизатор или нет). Таким образом, отсечение части генератора позволит избежать вычислений градиента для весов генератора (поскольку они не требуются).
  • Кроме того, только параметры, переданные конкретному optimizer, обновляются при вызове optimizer.step(). Таким образом, g_optim будет оптимизировать только параметры, передаваемые ему (вы не указываете явно, какие параметры передаются в g_optim). Аналогично, d_optim будет обновлять только d.parameters(), поскольку вы явно указываете это.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...