SRGAN был реализован с использованием PyTorch.
Предварительная подготовка генератора проводилась в 100 раз, а поезд SRGAN - в 200 раз.
Код представляет собой комбинацию существующих кодов GitHub.
Для потери содержимого использовалась MSELoss () в PyTorch, а BCELoss () в PyTorch использовалась для состязательной потери.
Когда я запускаю код, LossD сходится к 0, и LossG колеблется вокруг определенного значения. Поэтому я прекратил тренироваться, потому что думал, что это уже не тренировка.
Если обучение будет 1е5, как в статье, изменится ли результат? Или это функция потерь?
Ниже приведен учебный код SRGAN.
print('Adversarial training')
for epoch in range(NUM_EPOCHS):
train_bar = tqdm(train_loader)
running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}
# train_bar = tqdm(train_loader)
for data, target in train_bar:
batch_size = data.size(0)
running_results['batch_sizes'] += batch_size
target_real = Variable(torch.ones(batch_size, 1))
target_fake = Variable(torch.zeros(batch_size, 1))
if torch.cuda.is_available():
target_real = target_real.cuda()
target_fake = target_fake.cuda()
real_img = Variable(target)
z = Variable(data)
# Generate real and fake inputs
if torch.cuda.is_available():
inputsD_real = real_img.cuda()
inputsD_fake = netG(z.cuda())
else:
inputsD_real = real_img
inputsD_fake = netG(z)
######### Train discriminator #########
netD.zero_grad()
# With real data
outputs = netD(inputsD_real)
D_real = outputs.data.mean()
lossD_real = adversarial_criterion(outputs, target_real)
# With fake data
outputs = netD(inputsD_fake.detach()) # Don't need to compute gradients wrt weights of netG (for efficiency)
D_fake = outputs.data.mean()
lossD_fake = adversarial_criterion(outputs, target_fake)
lossD_total = lossD_real + lossD_fake
lossD_total.backward()
# Update discriminator weights
optimizerD.step()
######### Train generator #########
netG.zero_grad()
real_features = Variable(feature_extractor(inputsD_real).data)
fake_features = feature_extractor(inputsD_fake)
lossG_vgg19 = content_criterion(fake_features, real_features)
lossG_adversarial = adversarial_criterion(netD(inputsD_fake).detach(), target_real)
lossG_mse = content_criterion(inputsD_fake, inputsD_real)
lossG_total = lossG_mse + 2e-6 * lossG_vgg19 + 0.001 * lossG_adversarial
lossG_total.backward()
# Update generator weights
optimizerG.step()