GAN в Pytorch: дискриминатор побеждает, что-то неправильно, неправильная настройка - PullRequest
0 голосов
/ 08 июля 2019

Пытаясь реализовать GAN в Pytorch, я получаю результат, что генератор ничего не изучает (или плохо изучает), и дискриминатор хорошо воспроизводит (примерно 95% правильно). Я думаю, что-то не так с настройкой обратного распространения. Проект большой, поэтому я не публикую его полностью, только важное место внутри тренинга:

loss = torch.nn.CrossEntropyLoss()

...

for epoch in range(epochs):

for start_index in range(0,len(x_train), batch_size):


optimizer.zero_grad() 

    x_batch = x_train[start_index : start_index+batch_size]

    y_batch = y_train[start_index : start_index+batch_size]

    output = nnet.forward(x_batch)

    real_loss_value = loss(output, y_batch)

    x_gen, y_gen_false_real = ngen.rnd_batch(x_batch.size(0))






    x_gen = x_gen.view(-1,1,28,28)  

    y_gen_true_fake = y_gen_false_real + 10

    gen_output = nnet.forward(x_gen)


    gen_optimizer.zero_grad()

    gen_output = nnet.forward(x_gen)

    gen_success_loss =   loss(gen_output, y_gen_false_real)

    gen_success_loss.backward()      


    gen_optimizer.step()        


    # Measure discriminator's ability to classify real from generated samples
    # if fake recognized, the output will be 10-19

    gen_output = nnet.forward(x_gen.detach())

    fake_loss_value = loss(gen_output, y_gen_true_fake)

    d_loss = (real_loss_value + fake_loss_value) / 2

    d_loss.backward()

    optimizer.step()

    optimizer.zero_grad() 

Это не похоже на туториал для примера здесь https://github.com/eriklindernoren/PyTorch-GAN но я думаю, что следующие должны работать: Дискриминатор имеет выходной сигнал 20 флагов: первые 0-9 означают реальные цифры, последние 10-19 распознаются как поддельные генератора. соответствующие выводы сделаны в строке

y_gen_true_fake = y_gen_false_real + 10

С потерями d_loss = (real_loss_value + fake_loss_value) / 2 дискриминатор прекрасно работает даже после 1 эпохи, но генератор с gen_success_loss = loss(gen_output, y_gen_false_real) ничего не наклоняет и продолжает просто производить шум. Я предполагаю, что что-то в обратном распространении вызова идет не так, как я, я не совсем понимаю это множественные обратные вызовы. Вы можете помочь мне, пожалуйста?

...