DCGAN on an RGB dataset produces strange results

I am implementing a DCGAN in LibTorch/PyTorch, following the official example at https://github.com/pytorch/examples/blob/master/cpp/dcgan/dcgan.cpp. The only differences between my setup and the example are:

  • My dataset consists of RGB images (the CelebA dataset), while the example's is grayscale (MNIST); a sketch of the loader I use is shown right after this list
  • My images are 64×64, while the MNIST images are 28×28
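
The images are read by a CustomDataset class. Its header is not listed in this post, so the following is only a minimal sketch of such a loader, assuming OpenCV for decoding (illustrative, not my actual code):

#include <torch/torch.h>
#include <opencv2/opencv.hpp>

#include <string>
#include <vector>

// Minimal sketch of a folder-of-JPEGs dataset (illustrative only; the real
// CustomDataset.h is omitted from this post).
class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {
public:
    explicit CustomDataset(const std::string& glob_pattern) {
        cv::glob(glob_pattern, paths_);
    }

    torch::data::Example<> get(size_t index) override {
        cv::Mat img = cv::imread(paths_[index]);    // 8-bit BGR, interleaved HWC
        cv::resize(img, img, cv::Size(64, 64));
        cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
        // from_blob sees the OpenCV buffer as HWC, so permute to CHW
        // (reshaping instead would scramble the pixel data).
        torch::Tensor t =
                torch::from_blob(img.data, {64, 64, 3}, torch::kUInt8).clone();
        t = t.permute({2, 0, 1}).to(torch::kFloat32).div(255).mul(2).sub(1);
        return {t, torch::zeros(1)};                // label is unused for a GAN
    }

    torch::optional<size_t> size() const override {
        return paths_.size();
    }

private:
    std::vector<cv::String> paths_;
};

The detail that matters is the permute: OpenCV stores pixels interleaved (HWC), while the convolutions below expect planar CHW; the scaling to [-1, 1] matches the tanh at the end of the generator.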

Here is my training code:

#include <torch/torch.h>

#include <cmath>
#include <cstdio>
#include <iostream>
#include "CustomDataset.h"
#include "parameters.h"

// The size of the noise vector fed to the generator.
const int64_t kNoiseSize = 100;

// The batch size for training.
const int64_t kBatchSize = 64;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;

// Where to find the dataset (left over from the MNIST example; the actual
// CelebA path is hard-coded in main below).
const char* kDataFolder = "./data";

// After how many batches to create a new checkpoint periodically.
const int64_t kCheckpointEvery = 20;

// How many images to sample at every checkpoint.
const int64_t kNumberOfSamplesPerCheckpoint = 10;


// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;

using namespace torch;

struct DCGANGeneratorImpl : nn::Module {
    DCGANGeneratorImpl(int kNoiseSize)
            // [N, kNoiseSize, 1, 1] -> [N, 256, 4, 4]
            : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
                            .bias(false)),
              batch_norm1(256),
              // -> [N, 128, 8, 8]
              conv2(nn::ConvTranspose2dOptions(256, 128, 4)
                            .stride(2)
                            .padding(1)
                            .bias(false)),
              batch_norm2(128),
              // -> [N, 64, 16, 16]
              conv3(nn::ConvTranspose2dOptions(128, 64, 4)
                            .stride(2)
                            .padding(1)
                            .bias(false)),
              batch_norm3(64),
              // -> [N, 32, 32, 32]
              conv4(nn::ConvTranspose2dOptions(64, 32, 4)
                            .stride(2)
                            .padding(1)
                            .bias(false)),
              batch_norm4(32),
              // -> [N, 3, 64, 64]
              conv5(nn::ConvTranspose2dOptions(32, 3, 4)
                            .stride(2)
                            .padding(1)
                            .bias(false))

    {
        register_module("conv1", conv1);
        register_module("conv2", conv2);
        register_module("conv3", conv3);
        register_module("conv4", conv4);
        register_module("conv5", conv5);
        register_module("batch_norm1", batch_norm1);
        register_module("batch_norm2", batch_norm2);
        register_module("batch_norm3", batch_norm3);
        register_module("batch_norm4", batch_norm4);

    }

    torch::Tensor forward(torch::Tensor x)
    {
        x = torch::relu(batch_norm1(conv1(x)));
        x = torch::relu(batch_norm2(conv2(x)));
        x = torch::relu(batch_norm3(conv3(x)));
        x = torch::relu(batch_norm4(conv4(x)));
        x = torch::tanh(conv5(x));
        return x;
    }


    nn::ConvTranspose2d conv1, conv2, conv3, conv4, conv5;
    nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3, batch_norm4;
};

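// TORCH_MODULE generates the DCGANGenerator wrapper: a shared_ptr-like module
// holder around DCGANGeneratorImpl (hence the -> syntax used in main below).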
TORCH_MODULE(DCGANGenerator);

int main(int argc, const char* argv[]) {
    torch::manual_seed(1);

    // Create the device we pass around based on whether CUDA is available.
    torch::Device device(torch::kCPU);
    if (torch::cuda::is_available()) {
        std::cout << "CUDA is available! Training on GPU." << std::endl;
        device = torch::Device(torch::kCUDA);
    }

    DCGANGenerator generator(kNoiseSize);
    generator->to(device);

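    // The discriminator maps [N, 3, 64, 64] images to one [N, 1, 1, 1]
    // real/fake probability per image.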
    nn::Sequential discriminator(
            // Layer 1
            nn::Conv2d(
                    nn::Conv2dOptions(3, 64, 4).stride(2).padding(1).bias(false)),
            nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
            //output is 32x32
            // Layer 2
            nn::Conv2d(
                    nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
            nn::BatchNorm2d(128),
            nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
            //output is 16x16
            // Layer 3
            nn::Conv2d(
                    nn::Conv2dOptions(128, 64, 4).stride(2).padding(1).bias(false)),
            nn::BatchNorm2d(64),
            nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
            //output is 8x8
            // Layer 4
            nn::Conv2d(
                    nn::Conv2dOptions(64, 32, 5).stride(1).padding(0).bias(false)),
            nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
            // output is 4x4
            // Layer 5
            nn::Conv2d(
                    nn::Conv2dOptions(32, 1, 4).stride(1).padding(0).bias(false)),
            nn::Sigmoid());
    discriminator->to(device);

    // Where all my pictures are. Stack<>() collates the per-example tensors
    // into a single [N, 3, 64, 64] batch tensor.
    std::string file_location{"dataset/img_align_celeba/*.jpg"};
    auto dataset = CustomDataset(file_location).map(data::transforms::Stack<>());

    const int64_t batches_per_epoch =
            std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

    auto data_loader = torch::data::make_data_loader(
            std::move(dataset),
            torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));


    torch::optim::Adam generator_optimizer(
            generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
    torch::optim::Adam discriminator_optimizer(
            discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));


    int64_t checkpoint_counter = 1;
    for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
        int64_t batch_index = 0;
        for (torch::data::Example<>& batch : *data_loader) {
            // Train discriminator with real images.
            discriminator->zero_grad();
            torch::Tensor real_images = batch.data.to(device);
            torch::Tensor real_labels =
                    torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
            // Flatten the [N, 1, 1, 1] discriminator output so that it has
            // the same shape as the [N] label tensor; binary_cross_entropy
            // would otherwise broadcast the two against each other.
            torch::Tensor real_output =
                    discriminator->forward(real_images).reshape(real_labels.sizes());
            torch::Tensor d_loss_real =
                    torch::binary_cross_entropy(real_output, real_labels);
            d_loss_real.backward();

            // Train discriminator with fake images.
            torch::Tensor noise =
                    torch::randn({batch.data.size(0), kNoiseSize, 1, 1}, device);
            torch::Tensor fake_images = generator->forward(noise);
            torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
            torch::Tensor fake_output =
                    discriminator->forward(fake_images.detach())
                            .reshape(fake_labels.sizes());
            torch::Tensor d_loss_fake =
                    torch::binary_cross_entropy(fake_output, fake_labels);
            d_loss_fake.backward();

            torch::Tensor d_loss = d_loss_real + d_loss_fake;
            discriminator_optimizer.step();

            // Train generator.
            generator->zero_grad();
            fake_labels.fill_(1);
            fake_output = discriminator->forward(fake_images).reshape(fake_labels.sizes());
            torch::Tensor g_loss =
                    torch::binary_cross_entropy(fake_output, fake_labels);
            g_loss.backward();
            generator_optimizer.step();
            batch_index++;

            // Log the loss values, as in the official example (this is what
            // kLogInterval, batches_per_epoch, and d_loss above are for).
            if (batch_index % kLogInterval == 0) {
                std::printf(
                        "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
                        epoch, kNumberOfEpochs, batch_index, batches_per_epoch,
                        d_loss.item<float>(), g_loss.item<float>());
            }

            if (batch_index % kCheckpointEvery == 0) {
                // Checkpoint the model and optimizer state.
                torch::save(generator, "generator-checkpoint.pt");
                torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
                torch::save(discriminator, "discriminator-checkpoint.pt");
                torch::save(
                        discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
                // Sample the generator and save the images. Note that the
                // samples come out of tanh in [-1, 1]; the official example
                // rescales them with (samples + 1.0) / 2.0 before saving.
                torch::Tensor samples = generator->forward(torch::randn(
                        {kNumberOfSamplesPerCheckpoint, kNoiseSize, 1, 1}, device));
                torch::save(
                        samples,
                        torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
                std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
            }
        }
    }

    std::cout << "Training complete!" << std::endl;
}

Every so often I save a mini-batch of samples and display the result of feeding noise through the generator. The problem is that with the MNIST example the results look right, but in my case every output image shows 9 small faces instead of a single one (see the attached image).
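
The display code itself is not shown above; what I mean by "displaying" is a conversion along these lines (a sketch only, assuming OpenCV; the helper name is hypothetical):

#include <torch/torch.h>
#include <opencv2/opencv.hpp>

#include <string>

// Hypothetical helper: converts one generator sample from planar CHW,
// with values in [-1, 1] straight out of tanh, to an 8-bit image on disk.
void save_sample_png(const torch::Tensor& samples, int i, const std::string& path) {
    torch::Tensor img = samples[i]                    // [3, 64, 64]
                                .add(1).div(2)        // [-1, 1] -> [0, 1]
                                .mul(255).clamp(0, 255)
                                .to(torch::kUInt8)
                                .permute({1, 2, 0})   // CHW -> HWC for OpenCV
                                .contiguous();
    cv::Mat mat(64, 64, CV_8UC3, img.data_ptr());
    cv::cvtColor(mat, mat, cv::COLOR_RGB2BGR);
    cv::imwrite(path, mat);
}

The permute and RGB/BGR swap mirror the ones in the dataset sketch above.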

[attached image: generator output showing a grid of 9 small, near-identical faces]

How can the generator produce output of the correct shape, yet with 9 nearly identical faces instead of a single one?

...