Ошибка отладки GAN - PullRequest
       44

Ошибка отладки GAN

0 голосов
/ 08 июня 2018

Создание GAN для генерации изображений.Изображения имеют 3 цветовых канала, 96 x 96.

Все изображения, сгенерированные генератором в начале, являются черными, что является статистически крайне маловероятной проблемой.

Такжепотери для обеих сетей не улучшаются.

Я разместил весь код ниже и прокомментировал его, чтобы его было легко прочитать.Я впервые строю GAN, и я новичок в Pytorch, поэтому любая помощь очень ценится!

Спасибо.

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.autograd import Variable

import numpy as np
import os
import cv2
from collections import deque

# training params
batch_size = 100
epochs = 1000

# loss function
loss_fx = torch.nn.BCELoss()

# processing images
X = deque()
for img in os.listdir('pokemon_images'):
    if img.endswith('.png'):
        pokemon_image = cv2.imread(r'./pokemon_images/{}'.format(img))
        if pokemon_image.shape != (96, 96, 3):
            pass
        else:
            X.append(pokemon_image)

# data loader for processing in batches
data_loader = DataLoader(X, batch_size=batch_size)

# covert output vectors to images if flag is true, else input images to vectors
def images_to_vectors(data, reverse=False):
    if reverse:
        return data.view(data.size(0), 3, 96, 96)
    else:
        return data.view(data.size(0), 27648)

# Generator model
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        n_features = 1000
        n_out = 27648

        self.model = torch.nn.Sequential(
                torch.nn.Linear(n_features, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 512),
                torch.nn.ReLU(),
                torch.nn.Linear(512, 1024),
                torch.nn.ReLU(),
                torch.nn.Linear(1024, n_out),
                torch.nn.Tanh()
        )


    def forward(self, x):
        img = self.model(x)
        return img

    def noise(self, s):
       x = Variable(torch.randn(s, 1000))
       return x


# Discriminator model
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        n_features = 27648
        n_out = 1

        self.model = torch.nn.Sequential(
                torch.nn.Linear(n_features, 512),
                torch.nn.ReLU(),
                torch.nn.Linear(512, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, n_out),
                torch.nn.Sigmoid()
        )


    def forward(self, img):
        output = self.model(img)
        return output


# discriminator training
def train_discriminator(discriminator, optimizer, real_data, fake_data):
    N = real_data.size(0)
    optimizer.zero_grad()

    # train on real
    # get prediction
    pred_real = discriminator(real_data)

    # calculate loss
    error_real = loss_fx(pred_real, Variable(torch.ones(N, 1)))

    # calculate gradients
    error_real.backward()

    # train on fake
    # get prediction
    pred_fake = discriminator(fake_data)

    # calculate loss
    error_fake = loss_fx(pred_fake, Variable(torch.ones(N, 0)))

    # calculate gradients
    error_fake.backward()

    # update weights
    optimizer.step()

    return error_real + error_fake, pred_real, pred_fake


# generator training
def train_generator(generator, optimizer, fake_data):
    N = fake_data.size(0)

    # zero gradients
    optimizer.zero_grad()

    # get prediction
    pred = discriminator(generator(fake_data))

    # get loss
    error = loss_fx(pred, Variable(torch.ones(N, 0)))

    # compute gradients
    error.backward()

    # update weights
    optimizer.step()

    return error


# Instance of generator and discriminator
generator = Generator()
discriminator = Discriminator()

# optimizers
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)

# training loop
for epoch in range(epochs):
     for n_batch, batch in enumerate(data_loader, 0):
         N = batch.size(0)

         # Train Discriminator

         # REAL
         real_images = Variable(images_to_vectors(batch)).float()

         # FAKE
         fake_images = generator(generator.noise(N)).detach()

         # TRAIN
         d_error, d_pred_real, d_pred_fake = train_discriminator(
                 discriminator,
                 d_optimizer,
                 real_images,
                 fake_images
         )

         # Train Generator

         # generate noise
         fake_data = generator.noise(N)

         # get error based on discriminator
         g_error = train_generator(generator, g_optimizer, fake_data)

         # convert generator output to image and preprocess to show
         test_img = np.array(images_to_vectors(generator(fake_data), reverse=True).detach())
         test_img = test_img[0, :, :, :]
         test_img = test_img[..., ::-1]

         # show example of generated image
         cv2.imshow('GENERATED', test_img[0])
         if cv2.waitKey(1) & 0xFF == ord('q'):
             break

     print('EPOCH: {0}, D error: {1}, G error: {2}'.format(epoch, d_error, g_error))


cv2.destroyAllWindows()

# save weights
# torch.save('weights.pth')

1 Ответ

0 голосов
/ 08 июня 2018

Нельзя легко отладить тренировку без данных и т. Д., Но возможная проблема заключается в том, что последний слой вашего генератора - Tanh(), что означает выходные значения от -1 до 1.Возможно, вы захотите:

  1. , чтобы ваши реальные изображения были нормализованы к тому же диапазону, например, в train_discriminator():

    # train on real
    pred_real = discriminator(real_data * 2. - 1.) # supposing real_data in [0, 1]
    
  2. -нормализуйте сгенерированные данные до [0, 1] перед визуализацией / использованием.

    # convert generator output to image and preprocess to show
    test_img = np.array(
        images_to_vectors(generator(fake_data), reverse=True).detach())
    test_img = test_img[0, :, :, :]
    test_img = test_img[..., ::-1]
    test_img = (test_img + 1.) / 2.
    
...