Пытаясь реализовать GAN в Pytorch, я получаю результат, что генератор ничего не изучает (или плохо изучает), и дискриминатор хорошо воспроизводит (примерно 95% правильно). Я думаю, что-то не так с настройкой обратного распространения.
Проект большой, поэтому я не публикую его полностью, только важное место внутри тренинга:
loss = torch.nn.CrossEntropyLoss()
...
for epoch in range(epochs):
for start_index in range(0,len(x_train), batch_size):
optimizer.zero_grad()
x_batch = x_train[start_index : start_index+batch_size]
y_batch = y_train[start_index : start_index+batch_size]
output = nnet.forward(x_batch)
real_loss_value = loss(output, y_batch)
x_gen, y_gen_false_real = ngen.rnd_batch(x_batch.size(0))
x_gen = x_gen.view(-1,1,28,28)
y_gen_true_fake = y_gen_false_real + 10
gen_output = nnet.forward(x_gen)
gen_optimizer.zero_grad()
gen_output = nnet.forward(x_gen)
gen_success_loss = loss(gen_output, y_gen_false_real)
gen_success_loss.backward()
gen_optimizer.step()
# Measure discriminator's ability to classify real from generated samples
# if fake recognized, the output will be 10-19
gen_output = nnet.forward(x_gen.detach())
fake_loss_value = loss(gen_output, y_gen_true_fake)
d_loss = (real_loss_value + fake_loss_value) / 2
d_loss.backward()
optimizer.step()
optimizer.zero_grad()
Это не похоже на туториал для примера здесь https://github.com/eriklindernoren/PyTorch-GAN
но я думаю, что следующие должны работать:
Дискриминатор имеет выходной сигнал 20 флагов: первые 0-9 означают реальные цифры, последние 10-19 распознаются как поддельные генератора. соответствующие выводы сделаны в строке
y_gen_true_fake = y_gen_false_real + 10
С потерями d_loss = (real_loss_value + fake_loss_value) / 2
дискриминатор прекрасно работает даже после 1 эпохи, но генератор с gen_success_loss = loss(gen_output, y_gen_false_real)
ничего не наклоняет и продолжает просто производить шум. Я предполагаю, что что-то в обратном распространении вызова идет не так, как я, я не совсем понимаю это множественные обратные вызовы. Вы можете помочь мне, пожалуйста?