Почему моя функция потерь генератора увеличивается с итерациями - PullRequest
2 голосов
/ 23 марта 2019

Я пытаюсь обучить DC-GAN на наборе данных CIFAR-10.Я использую двоичную перекрестную энтропию как функцию потерь для дискриминатора и генератора (с добавлением необучаемого дискриминатора).Если я тренируюсь с использованием Adam Optimizer, GAN тренируется нормально.Но если я заменим оптимизатор на SGD, обучение будет бесполезным.Точность генератора начинается с некоторой более высокой точки и с итерациями достигает 0 и остается там.Точность дискриминатора начинается с некоторой нижней точки и достигает где-то около 0,5 (ожидаемо, верно?).Особенность в том, что функция потерь генератора увеличивается с итерациями.Я, хотя может быть, шаг слишком высокЯ пытался изменить размер шага.Я пытался использовать импульс с SGD.Во всех этих случаях генератор может уменьшаться или не уменьшаться в начале, но затем точно увеличивается.Итак, я думаю, что в моей модели что-то не так.Я знаю, что тренировать Deep Models сложно, а GAN еще больше, но должна быть какая-то причина / эвристика в отношении того, почему это происходит.Любые входы в ценится.Я новичок в области нейронных сетей, глубокого обучения и, следовательно, также новичок в GAN.

Вот мой код: Cifar10Models.py

from keras import Sequential
from keras.initializers import TruncatedNormal
from keras.layers import Activation, BatchNormalization, Conv2D, Conv2DTranspose, Dense, Flatten, LeakyReLU, Reshape
from keras.optimizers import SGD


class DcGan:
    def __init__(self, print_model_summary: bool = False):
        self.generator_model = None
        self.discriminator_model = None
        self.concatenated_model = None
        self.print_model_summary = print_model_summary

    def build_generator_model(self):
        if self.generator_model:
            return self.generator_model

        self.generator_model = Sequential()
        self.generator_model.add(Dense(4 * 4 * 512, input_dim=100,
                                       kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.generator_model.add(BatchNormalization(momentum=0.5))
        self.generator_model.add(Activation('relu'))
        self.generator_model.add(Reshape((4, 4, 512)))

        self.generator_model.add(Conv2DTranspose(256, 3, strides=2, padding='same',
                                                 kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.generator_model.add(BatchNormalization(momentum=0.5))
        self.generator_model.add(Activation('relu'))

        self.generator_model.add(Conv2DTranspose(128, 3, strides=2, padding='same',
                                                 kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.generator_model.add(BatchNormalization(momentum=0.5))
        self.generator_model.add(Activation('relu'))

        self.generator_model.add(Conv2DTranspose(64, 3, strides=2, padding='same',
                                                 kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.generator_model.add(BatchNormalization(momentum=0.5))
        self.generator_model.add(Activation('relu'))

        self.generator_model.add(Conv2D(3, 3, padding='same',
                                        kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.generator_model.add(Activation('tanh'))

        if self.print_model_summary:
            self.generator_model.summary()

        return self.generator_model

    def build_discriminator_model(self):
        if self.discriminator_model:
            return self.discriminator_model

        self.discriminator_model = Sequential()
        self.discriminator_model.add(Conv2D(128, 3, strides=2, input_shape=(32, 32, 3), padding='same',
                                            kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.discriminator_model.add(LeakyReLU(alpha=0.2))

        self.discriminator_model.add(Conv2D(256, 3, strides=2, padding='same',
                                            kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.generator_model.add(BatchNormalization(momentum=0.5))
        self.discriminator_model.add(LeakyReLU(alpha=0.2))

        self.discriminator_model.add(Conv2D(512, 3, strides=2, padding='same',
                                            kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.generator_model.add(BatchNormalization(momentum=0.5))
        self.discriminator_model.add(LeakyReLU(alpha=0.2))

        self.discriminator_model.add(Conv2D(1024, 3, strides=2, padding='same',
                                            kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.generator_model.add(BatchNormalization(momentum=0.5))
        self.discriminator_model.add(LeakyReLU(alpha=0.2))

        self.discriminator_model.add(Flatten())
        self.discriminator_model.add(Dense(1, kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        self.generator_model.add(BatchNormalization(momentum=0.5))
        self.discriminator_model.add(Activation('sigmoid'))

        if self.print_model_summary:
            self.discriminator_model.summary()

        return self.discriminator_model

    def build_concatenated_model(self):
        if self.concatenated_model:
            return self.concatenated_model

        self.concatenated_model = Sequential()
        self.concatenated_model.add(self.generator_model)
        self.concatenated_model.add(self.discriminator_model)

        if self.print_model_summary:
            self.concatenated_model.summary()

        return self.concatenated_model

    def build_dc_gan(self):
        self.build_generator_model()
        self.build_discriminator_model()
        self.build_concatenated_model()

        self.discriminator_model.trainable = True
        optimizer = SGD(lr=0.0002)
        self.discriminator_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        self.discriminator_model.trainable = False
        optimizer = SGD(lr=0.0001)
        self.concatenated_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        self.discriminator_model.trainable = True

Cifar10Trainer.py:

# Shree KRISHNAya Namaha
# Based on https://towardsdatascience.com/gan-by-example-using-keras-on-tensorflow-backend-1a6d515a60d0

import os

import datetime
import numpy
import time
from keras.datasets import cifar10
from keras.utils import np_utils
from matplotlib import pyplot as plt

import Cifar10Models

log_file_name = 'logs.csv'


class Cifar10Trainer:
    def __init__(self):
        self.x_train, self.y_train = self.get_train_and_test_data()
        self.dc_gan = Cifar10Models.DcGan()
        self.dc_gan.build_dc_gan()

    @staticmethod
    def get_train_and_test_data():
        (x_train, y_train), _ = cifar10.load_data()
        x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 3)
        # Generator output has tanh activation whose range is [-1,1]
        x_train = (x_train.astype('float32') * 2 / 255) - 1
        y_train = np_utils.to_categorical(y_train, 10)
        return x_train, y_train

    def train(self, train_steps=10000, batch_size=128, log_interval=10, save_interval=100,
              output_folder_path='./Trained_Models/'):
        self.initialize_log(output_folder_path)
        self.sample_real_images(output_folder_path)
        for i in range(train_steps):
            # Get real (Database) Images
            images_real = self.x_train[numpy.random.randint(0, self.x_train.shape[0], size=batch_size), :, :, :]

            # Generate Fake Images
            noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            images_fake = self.dc_gan.generator_model.predict(noise)

            # Train discriminator on both real and fake images
            x = numpy.concatenate((images_real, images_fake), axis=0)
            y = numpy.ones([2 * batch_size, 1])
            y[batch_size:, :] = 0
            d_loss = self.dc_gan.discriminator_model.train_on_batch(x, y)

            # Train generator i.e. concatenated model
            noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            y = numpy.ones([batch_size, 1])
            g_loss = self.dc_gan.concatenated_model.train_on_batch(noise, y)

            # Print Logs, Save Models, generate sample images
            if (i + 1) % log_interval == 0:
                self.log_progress(output_folder_path, i + 1, g_loss, d_loss)
            if (i + 1) % save_interval == 0:
                self.save_models(output_folder_path, i + 1)
                self.generate_images(output_folder_path, i + 1)

    @staticmethod
    def initialize_log(output_folder_path):
        log_line = 'Iteration No, Generator Loss, Generator Accuracy, Discriminator Loss, Discriminator Accuracy, ' \
                   'Time\n'
        with open(os.path.join(output_folder_path, log_file_name), 'w') as log_file:
            log_file.write(log_line)

    @staticmethod
    def log_progress(output_folder_path, iteration_no, g_loss, d_loss):
        log_line = '{0:05},{1:2.4f},{2:0.4f},{3:2.4f},{4:0.4f},{5}\n' \
            .format(iteration_no, g_loss[0], g_loss[1], d_loss[0], d_loss[1],
                    datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        with open(os.path.join(output_folder_path, log_file_name), 'a') as log_file:
            log_file.write(log_line)
        print(log_line)

    def save_models(self, output_folder_path, iteration_no):
        self.dc_gan.generator_model.save(
            os.path.join(output_folder_path, 'generator_model_{0}.h5'.format(iteration_no)))
        self.dc_gan.discriminator_model.save(
            os.path.join(output_folder_path, 'discriminator_model_{0}.h5'.format(iteration_no)))
        self.dc_gan.concatenated_model.save(
            os.path.join(output_folder_path, 'concatenated_model_{0}.h5'.format(iteration_no)))

    def sample_real_images(self, output_folder_path):
        filepath = os.path.join(output_folder_path, 'CIFAR10_Sample_Real_Images.png')
        i = numpy.random.randint(0, self.x_train.shape[0], 16)
        images = self.x_train[i, :, :, :]
        plt.figure(figsize=(10, 10))
        for i in range(16):
            plt.subplot(4, 4, i + 1)
            image = images[i, :, :, :]
            image = numpy.reshape(image, [32, 32, 3])
            plt.imshow(image)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(filepath)
        plt.close('all')

    def generate_images(self, output_folder_path, iteration_no, noise=None):
        filepath = os.path.join(output_folder_path, 'CIFAR10_Gen_Image{0}.png'.format(iteration_no))
        if noise is None:
            noise = numpy.random.uniform(-1, 1, size=[16, 100])
        # Generator output has tanh activation whose range is [-1,1]
        images = (self.dc_gan.generator_model.predict(noise) + 1) / 2
        plt.figure(figsize=(10, 10))
        for i in range(16):
            plt.subplot(4, 4, i + 1)
            image = images[i, :, :, :]
            image = numpy.reshape(image, [32, 32, 3])
            plt.imshow(image)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(filepath)
        plt.close('all')


def main():
    cifar10_trainer = Cifar10Trainer()
    cifar10_trainer.train(train_steps=10000, log_interval=10, save_interval=100)
    del cifar10_trainer.dc_gan
    return


if __name__ == '__main__':
    start_time = time.time()
    print('Program Started at {0}'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))))
    main()
    end_time = time.time()
    print('Program Ended at {0}'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))))
    print('Total Execution Time: {0}s'.format(datetime.timedelta(seconds=end_time - start_time)))

Некоторые графики приведены ниже:

  1. Оптимизатор дискриминатора: SGD (lr = 0,0001, бета1 = 0,5)
    Оптимизатор генератора: Адам (lr = 0,0001, бета1= 0,5) enter image description here

  2. Оптимизатор дискриминатора: SGD (lr = 0,0001)
    Оптимизатор генератора: SGD (lr = 0,0001) enter image description here

  3. Оптимизатор дискриминатора: SGD (lr = 0,0001)
    Оптимизатор генератора: SGD (lr = 0,001) enter image description here

  4. Оптимизатор дискриминатора: SGD (lr = 0,0001)
    Оптимизатор генератора: SGD (lr = 0,0005) enter image description here

...