Все большие и большие положительные потери WGAN-GP - PullRequest
0 голосов
/ 26 ноября 2018

Я исследую использование Вассерштейна GAN с градиентным штрафом в PyTorch, но постоянно получаю большие положительные потери генератора, которые увеличиваются в течение эпох.Я сильно заимствую из реализации Caogang , но я использую дискриминатор и потери генератора, используемые в этой реализации , потому что я получаю Invalid gradient at index 0 - expected shape[] but got [1], если я пытаюсь вызвать .backward() с помощью *Аргументы 1007 * и mone, используемые в реализации Caogang.

Я тренируюсь на расширенном наборе данных WikiArt (> 400k 64x64 изображений) и CIFAR-10, и получил обычный WGAN (с отсечкой веса доработа) [т.е. он производит проходимые изображения после 25 эпох], несмотря на тот факт, что потери D и G колеблются около 3 [я рассчитываю их, используя torch.mean(D_real) и т. д.] для всех эпох.Однако в версии WGAN-GP потери генератора резко возрастают как в наборах данных WikiArt, так и в наборах CIFAR-10, и полностью не генерируют ничего, кроме шума в WikiArt.

Вот пример потери после 25эпохи в CIFAR-10: WGAN-GP loss

Я не использую никаких приемов, таких как одностороннее сглаживание меток, и я тренируюсь со скоростью обучения по умолчанию 0,001, оптимизатором Адама иЯ тренирую дискриминатор 5 раз для каждого обновления генератора.Почему происходит такое сумасшедшее поведение при потере, и почему обычный WGAN с ограничением веса все еще «работает» на WikiArt, но WGANGP полностью терпит неудачу?

Это происходит независимо от структуры, являются ли G и D DCGAN или когдаиспользуя этот модифицированный DCGAN, Creative Adversarial Network , который требует, чтобы D был в состоянии классифицировать изображения, а G. генерировать неоднозначные изображения.

Ниже приведена соответствующая часть моего текущего train метода:

self.generator = Can64Generator(self.z_noise, self.channels, self.num_gen_filters).to(self.device)
self.discriminator =WCan64Discriminator(self.channels,self.y_dim, self.num_disc_filters).to(self.device)
style_criterion = nn.CrossEntropyLoss()

self.disc_optimizer = optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.9))
self.gen_optimizer = optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.9))


while i < len(dataloader):
            j = 0
            disc_loss_epoch = []
            gen_loss_epoch = []
            if self.type == "can":
                disc_class_loss_epoch = []
                gen_class_loss_epoch = []

            if self.gradient_penalty == False:
                # critic training methodology in official WGAN implementation
                if gen_iterations < 25 or (gen_iterations % 500 == 0):
                    disc_iters = 100
            else:
                disc_iters = self.disc_iterations

            while j < disc_iters and (i < len(dataloader)):
                # if using wgan with weight clipping
                if self.gradient_penalty == False:
                    # Train Discriminator
                    for param in self.discriminator.parameters():
                        param.data.clamp_(self.lower_clamp,self.upper_clamp)


                for param in self.discriminator.parameters():
                    param.requires_grad_(True)

                j+=1
                i+=1
                data = data_iterator.next()
                self.discriminator.zero_grad()
                real_images, image_labels = data
                # image labels are the the image's classes (e.g. Impressionism)
                real_images = real_images.to(self.device) 
                batch_size = real_images.size(0)
                real_image_labels = torch.LongTensor(batch_size).to(self.device)
                real_image_labels.copy_(image_labels)

                labels = torch.full((batch_size,),real_label,device=self.device)

                if self.type == 'can':
                    predicted_output_real, predicted_styles_real = self.discriminator(real_images.detach())
                    predicted_styles_real = predicted_styles_real.to(self.device)
                    disc_class_loss = style_criterion(predicted_styles_real,real_image_labels)
                    disc_class_loss.backward(retain_graph=True)

                else:
                    predicted_output_real = self.discriminator(real_images.detach())

                disc_loss_real = -torch.mean(predicted_output_real)


                # fake

                noise = torch.randn(batch_size,self.z_noise,1,1,device=self.device)
                with torch.no_grad():
                    noise_g = noise.detach()
                fake_images = self.generator(noise_g)
                labels.fill_(fake_label)

                if self.type == 'can':
                    predicted_output_fake, predicted_styles_fake = self.discriminator(fake_images)

                else:
                    predicted_output_fake = self.discriminator(fake_images)



                disc_gen_z_1 = predicted_output_fake.mean().item()

                disc_loss_fake = torch.mean(predicted_output_fake)


                #via https://github.com/znxlwm/pytorch-generative-model-collections/blob/master/WGAN_GP.py
                if self.gradient_penalty:
                    # gradient penalty
                    alpha = torch.rand((real_images.size()[0], 1, 1, 1)).to(self.device) 
                    x_hat = alpha * real_images.data + (1 - alpha) * fake_images.data
                    x_hat.requires_grad_(True)
                    if self.type == 'can':
                        pred_hat, _ = self.discriminator(x_hat)
                    else:
                        pred_hat = self.discriminator(x_hat)
                    gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).to(self.device),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]

                    gradient_penalty = lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()
                    disc_loss = disc_loss_fake + disc_loss_real + gradient_penalty
                else:
                    disc_loss  =  disc_loss_fake  + disc_loss_real


                if self.type == 'can':
                    disc_loss += disc_class_loss.mean()

                disc_x = disc_loss.mean().item()
                disc_loss.backward(retain_graph=True)
                self.disc_optimizer.step()



            # train generator
            for param in self.discriminator.parameters():
                param.requires_grad_(False)

            self.generator.zero_grad()
            labels.fill_(real_label)

            if self.type == 'can':
                predicted_output_fake, predicted_styles_fake = self.discriminator(fake_images)
                predicted_styles_fake = predicted_styles_fake.to(self.device)

            else:
                predicted_output_fake = self.discriminator(fake_images)

            gen_loss = -torch.mean(predicted_output_fake)
            disc_gen_z_2 = gen_loss.mean().item()

            if self.type == 'can':
                fake_batch_labels = 1.0/self.y_dim * torch.ones_like(predicted_styles_fake)
                fake_batch_labels = torch.mean(fake_batch_labels,1).long().to(self.device)
                gen_class_loss = style_criterion(predicted_styles_fake,fake_batch_labels)
                gen_class_loss.backward(retain_graph=True)
                gen_loss += gen_class_loss.mean()

            gen_loss.backward()
            gen_iterations += 1

Это код для генератора (DCGAN):

class Can64Generator(nn.Module):
def __init__(self, z_noise, channels, num_gen_filters):
    super(Can64Generator,self).__init__()
    self.ngpu = 1
    self.main = nn.Sequential(
    nn.ConvTranspose2d(z_noise, num_gen_filters * 16, 4, 1, 0, bias=False),
    nn.BatchNorm2d(num_gen_filters * 16),
    nn.ReLU(True),
    nn.ConvTranspose2d(num_gen_filters * 16, num_gen_filters * 4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(num_gen_filters * 4),
    nn.ReLU(True),
    nn.ConvTranspose2d(num_gen_filters * 4, num_gen_filters * 2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(num_gen_filters * 2),
    nn.ReLU(True),
    nn.ConvTranspose2d(num_gen_filters * 2, num_gen_filters, 4, 2, 1, bias=False),
    nn.BatchNorm2d(num_gen_filters),
    nn.ReLU(True),
    nn.ConvTranspose2d(num_gen_filters, 3, 4, 2, 1, bias=False),
    nn.Tanh()
    )
def forward(self, inp):
    output = self.main(inp)
    return output

И это (текущий) дискриминатор CAN, который имеет дополнительные слои для классификации по стилю (классу изображения)):

class Can64Discriminator(nn.Module):

def __init__(self, channels,y_dim, num_disc_filters):
        super(Can64Discriminator, self).__init__()
        self.ngpu = 1
        self.conv = nn.Sequential(
                nn.Conv2d(channels, num_disc_filters // 2, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),

                nn.Conv2d(num_disc_filters // 2, num_disc_filters, 4, 2, 1, bias=False),
                nn.BatchNorm2d(num_disc_filters),
                nn.LeakyReLU(0.2, inplace=True),

                nn.Conv2d(num_disc_filters, num_disc_filters * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(num_disc_filters * 2),
                nn.LeakyReLU(0.2, inplace=True),

                nn.Conv2d(num_disc_filters * 2, num_disc_filters * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(num_disc_filters * 4),
                nn.LeakyReLU(0.2, inplace=True),

                nn.Conv2d(num_disc_filters * 4, num_disc_filters * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(num_disc_filters * 8),
                nn.LeakyReLU(0.2, inplace=True),

            )
        # was this
        #self.final_conv = nn.Conv2d(num_disc_filters * 8, num_disc_filters * 8, 4, 2, 1, bias=False)

        self.real_fake_head = nn.Linear(num_disc_filters * 8, 1)

        # no bn and lrelu needed
        self.sig = nn.Sigmoid()
        self.fc = nn.Sequential() 
        self.fc.add_module("linear_layer{0}".format(num_disc_filters*16),nn.Linear(num_disc_filters*8,num_disc_filters*16))
        self.fc.add_module("linear_layer{0}".format(num_disc_filters*8),nn.Linear(num_disc_filters*16,num_disc_filters*8))
        self.fc.add_module("linear_layer{0}".format(num_disc_filters),nn.Linear(num_disc_filters*8,y_dim))
        self.fc.add_module('softmax',nn.Softmax(dim=1))

def forward(self, inp):
    x = self.conv(inp)
    x = x.view(x.size(0),-1) 
    real_out = self.sig(self.real_fake_head(x))
    real_out = real_out.view(-1,1).squeeze(1)
    style = self.fc(x) 
    #style = torch.mean(style,1) # CrossEntropyLoss requires input be (N,C)
    return real_out,style

Единственная разница между версией WGANGP и версией WGAN моей GAN заключается в том, что версия WGAN использует RMSprop с lr=0.00005 и обрезает вес дискриминатора, как указано в документе WGAN..

Что может быть причиной этого?Я хотел бы внести как можно меньше изменений, так как я хочу сравнивать только функции потерь.Та же проблема встречается даже при использовании неизмененного дискриминатора DCGAN на CIFAR-10.Я сталкиваюсь с этим, возможно, потому что я тренируюсь в настоящее время только для 25 эпох, или есть другая причина?Интересно, что мой GAN также не может генерировать ничего, кроме шума, при использовании LSGAN (nn.MSELoss()).

Заранее спасибо!

1 Ответ

0 голосов
/ 26 ноября 2018

Пакетная нормализация в дискриминаторе разрывает ГАС Вассерштейна с градиентным штрафом.Сами авторы выступают за использование нормализации слоев, но это ясно написано жирным шрифтом в их статье (https://papers.nips.cc/paper/7159-improved-training-of-wasserstein-gans.pdf). Трудно сказать, есть ли другие ошибки в вашем коде, но я призываю вас внимательно прочитатьDCGAN и статья Вассерштейна GAN действительно делают заметки о гиперпараметрах. Неправильное их определение действительно снижает производительность GAN, а поиск гиперпараметров довольно быстро обходится дорого.выводите изображения. Вместо этого используйте изменение размера изображения. Для подробного объяснения этого явления я могу порекомендовать следующий ресурс (https://distill.pub/2016/deconv-checkerboard/).

...