Реализация Вассерштейна ГАН в pytorch. Как осуществить потерю? - PullRequest
0 голосов
/ 23 апреля 2020

В настоящее время я работаю над проектом в pytorch на Wassertein GAN (https://arxiv.org/pdf/1701.07875.pdf).

В Wasserstain GAN новая целевая функция определяется с использованием расстояния wasserstein следующим образом: enter image description here

Что приводит к следующим алгоритмам обучения ГАН:

enter image description here

Мой вопрос:

При реализации строк 5 и 6 алгоритма в pytorch я должен умножать свою потерю -1? Как и в моем коде (я использую RMSprop в качестве моего оптимизатора как для генератора, так и для критического анализа c):

############################
  # (1) Update D network: maximize (D(x)) + (D(G(x)))
  ###########################
  for n in range(n_critic):


    D.zero_grad()

    real_cpu = data[0].to(device)
    b_size = real_cpu.size(0)


    output = D(real_cpu)

    #errD_real = -criterion(output, label) #DCGAN
    errD_real = torch.mean(output)

    # Calculate gradients for D in backward pass
    errD_real.backward()
    D_x = output.mean().item()

    ## Train with all-fake batch
    # Generate batch of latent vectors
    noise = torch.randn(b_size, 100, device=device) #Careful here we changed shape of input (original : torch.randn(4, 100, 1, 1, device=device))

    # Generate fake image batch with G
    fake = G(noise)

    # Classify all fake batch with D
    output = D(fake.detach())


    # Calculate D's loss on the all-fake batch
    errD_fake = torch.mean(output)

    # Calculate the gradients for this batch
    errD_fake.backward()
    D_G_z1 = output.mean().item()

    # Add the gradients from the all-real and all-fake batches
    errD = -(errD_real - errD_fake)

    # Update D
    optimizerD.step()

    #Clipping weights
    for p in D.parameters():
      p.data.clamp_(-0.01, 0.01)

Как видите, я выполняю операцию errD = - (errD_real - errD_fake), с errD_real и errD_fake, соответственно, являются средними значениями прогнозов критических значений c для реальных и поддельных выборок.

Насколько я понимаю, RMSprop следует оптимизировать веса критических значений c следующим образом:

w <- w - альфа * градиент (w) </p>

(альфа - скорость обучения, деленная на квадрат root взвешенного скользящего среднего квадрата градиента)

Поскольку Задача оптимизации требует, чтобы "go" было в том же направлении, что и градиент, и это необходимо для умножения градиента (w) на -1, прежде чем оптимизировать веса.

Как вы думаете, мои рассуждения да?

Программа работает, но мои результаты довольно плохие.

Я следую той же логике c для весов генератора, но на этот раз для go в напротив г наклон градиента:

      ############################
  # (2) Update G network: minimize -D(G(x))
  ###########################
  G.zero_grad()

  noise = torch.randn(b_size, 100, device=device)
  fake = G(noise)

  #label.fill_(fake_label)  # fake labels are real for generator cost

  # Since we just updated D, perform another forward pass of all-fake batch through D
  output = D(fake).view(-1)

  # Calculate G's loss based on this output

  #errG = criterion(output, label) #DCGAN
  errG = -torch.mean(output)

  # Calculate gradients for G
  errG.backward()
  D_G_z2 = output.mean().item()

  # Update G
  optimizerG.step()

Извините за длинный вопрос, я попытался объяснить свои сомнения как можно яснее. Спасибо всем.

1 Ответ

0 голосов
/ 24 апреля 2020

Я заметил некоторые ошибки в реализации вашего протокола обучения дискриминатора. Вы вызываете свои обратные функции дважды, причем потеря действительных и поддельных значений происходит с обратным распространением на разных временных шагах.

Технически реализация с использованием этой схемы возможна, но крайне нечитаема. С вашим errD_real произошла ошибка, из-за которой ваш результат будет положительным, а не отрицательным как оптимальный D(G(z))>0, поэтому вы штрафуете его за правильность. В целом ваша модель сходится просто путем прогнозирования D(x)<0 для всех входных данных.

Чтобы это исправить, не называйте ваш errD_readl.backward() или errD_fake.backward(). Простое использование errD.backward() после определения errD будет работать отлично. В противном случае ваш генератор кажется правильным.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...