Я пытаюсь обучить DC-GAN на наборе данных CIFAR-10.Я использую двоичную перекрестную энтропию как функцию потерь для дискриминатора и генератора (с добавлением необучаемого дискриминатора).Если я тренируюсь с использованием Adam Optimizer, GAN тренируется нормально.Но если я заменим оптимизатор на SGD, обучение будет бесполезным.Точность генератора начинается с некоторой более высокой точки и с итерациями достигает 0 и остается там.Точность дискриминатора начинается с некоторой нижней точки и достигает где-то около 0,5 (ожидаемо, верно?).Особенность в том, что функция потерь генератора увеличивается с итерациями.Я, хотя может быть, шаг слишком высокЯ пытался изменить размер шага.Я пытался использовать импульс с SGD.Во всех этих случаях генератор может уменьшаться или не уменьшаться в начале, но затем точно увеличивается.Итак, я думаю, что в моей модели что-то не так.Я знаю, что тренировать Deep Models сложно, а GAN еще больше, но должна быть какая-то причина / эвристика в отношении того, почему это происходит.Любые входы в ценится.Я новичок в области нейронных сетей, глубокого обучения и, следовательно, также новичок в GAN.
Вот мой код: Cifar10Models.py
from keras import Sequential
from keras.initializers import TruncatedNormal
from keras.layers import Activation, BatchNormalization, Conv2D, Conv2DTranspose, Dense, Flatten, LeakyReLU, Reshape
from keras.optimizers import SGD
class DcGan:
def __init__(self, print_model_summary: bool = False):
self.generator_model = None
self.discriminator_model = None
self.concatenated_model = None
self.print_model_summary = print_model_summary
def build_generator_model(self):
if self.generator_model:
return self.generator_model
self.generator_model = Sequential()
self.generator_model.add(Dense(4 * 4 * 512, input_dim=100,
kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.generator_model.add(BatchNormalization(momentum=0.5))
self.generator_model.add(Activation('relu'))
self.generator_model.add(Reshape((4, 4, 512)))
self.generator_model.add(Conv2DTranspose(256, 3, strides=2, padding='same',
kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.generator_model.add(BatchNormalization(momentum=0.5))
self.generator_model.add(Activation('relu'))
self.generator_model.add(Conv2DTranspose(128, 3, strides=2, padding='same',
kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.generator_model.add(BatchNormalization(momentum=0.5))
self.generator_model.add(Activation('relu'))
self.generator_model.add(Conv2DTranspose(64, 3, strides=2, padding='same',
kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.generator_model.add(BatchNormalization(momentum=0.5))
self.generator_model.add(Activation('relu'))
self.generator_model.add(Conv2D(3, 3, padding='same',
kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.generator_model.add(Activation('tanh'))
if self.print_model_summary:
self.generator_model.summary()
return self.generator_model
def build_discriminator_model(self):
if self.discriminator_model:
return self.discriminator_model
self.discriminator_model = Sequential()
self.discriminator_model.add(Conv2D(128, 3, strides=2, input_shape=(32, 32, 3), padding='same',
kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.discriminator_model.add(LeakyReLU(alpha=0.2))
self.discriminator_model.add(Conv2D(256, 3, strides=2, padding='same',
kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.generator_model.add(BatchNormalization(momentum=0.5))
self.discriminator_model.add(LeakyReLU(alpha=0.2))
self.discriminator_model.add(Conv2D(512, 3, strides=2, padding='same',
kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.generator_model.add(BatchNormalization(momentum=0.5))
self.discriminator_model.add(LeakyReLU(alpha=0.2))
self.discriminator_model.add(Conv2D(1024, 3, strides=2, padding='same',
kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.generator_model.add(BatchNormalization(momentum=0.5))
self.discriminator_model.add(LeakyReLU(alpha=0.2))
self.discriminator_model.add(Flatten())
self.discriminator_model.add(Dense(1, kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
self.generator_model.add(BatchNormalization(momentum=0.5))
self.discriminator_model.add(Activation('sigmoid'))
if self.print_model_summary:
self.discriminator_model.summary()
return self.discriminator_model
def build_concatenated_model(self):
if self.concatenated_model:
return self.concatenated_model
self.concatenated_model = Sequential()
self.concatenated_model.add(self.generator_model)
self.concatenated_model.add(self.discriminator_model)
if self.print_model_summary:
self.concatenated_model.summary()
return self.concatenated_model
def build_dc_gan(self):
self.build_generator_model()
self.build_discriminator_model()
self.build_concatenated_model()
self.discriminator_model.trainable = True
optimizer = SGD(lr=0.0002)
self.discriminator_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
self.discriminator_model.trainable = False
optimizer = SGD(lr=0.0001)
self.concatenated_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
self.discriminator_model.trainable = True
Cifar10Trainer.py:
# Shree KRISHNAya Namaha
# Based on https://towardsdatascience.com/gan-by-example-using-keras-on-tensorflow-backend-1a6d515a60d0
import os
import datetime
import numpy
import time
from keras.datasets import cifar10
from keras.utils import np_utils
from matplotlib import pyplot as plt
import Cifar10Models
log_file_name = 'logs.csv'
class Cifar10Trainer:
def __init__(self):
self.x_train, self.y_train = self.get_train_and_test_data()
self.dc_gan = Cifar10Models.DcGan()
self.dc_gan.build_dc_gan()
@staticmethod
def get_train_and_test_data():
(x_train, y_train), _ = cifar10.load_data()
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 3)
# Generator output has tanh activation whose range is [-1,1]
x_train = (x_train.astype('float32') * 2 / 255) - 1
y_train = np_utils.to_categorical(y_train, 10)
return x_train, y_train
def train(self, train_steps=10000, batch_size=128, log_interval=10, save_interval=100,
output_folder_path='./Trained_Models/'):
self.initialize_log(output_folder_path)
self.sample_real_images(output_folder_path)
for i in range(train_steps):
# Get real (Database) Images
images_real = self.x_train[numpy.random.randint(0, self.x_train.shape[0], size=batch_size), :, :, :]
# Generate Fake Images
noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
images_fake = self.dc_gan.generator_model.predict(noise)
# Train discriminator on both real and fake images
x = numpy.concatenate((images_real, images_fake), axis=0)
y = numpy.ones([2 * batch_size, 1])
y[batch_size:, :] = 0
d_loss = self.dc_gan.discriminator_model.train_on_batch(x, y)
# Train generator i.e. concatenated model
noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
y = numpy.ones([batch_size, 1])
g_loss = self.dc_gan.concatenated_model.train_on_batch(noise, y)
# Print Logs, Save Models, generate sample images
if (i + 1) % log_interval == 0:
self.log_progress(output_folder_path, i + 1, g_loss, d_loss)
if (i + 1) % save_interval == 0:
self.save_models(output_folder_path, i + 1)
self.generate_images(output_folder_path, i + 1)
@staticmethod
def initialize_log(output_folder_path):
log_line = 'Iteration No, Generator Loss, Generator Accuracy, Discriminator Loss, Discriminator Accuracy, ' \
'Time\n'
with open(os.path.join(output_folder_path, log_file_name), 'w') as log_file:
log_file.write(log_line)
@staticmethod
def log_progress(output_folder_path, iteration_no, g_loss, d_loss):
log_line = '{0:05},{1:2.4f},{2:0.4f},{3:2.4f},{4:0.4f},{5}\n' \
.format(iteration_no, g_loss[0], g_loss[1], d_loss[0], d_loss[1],
datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
with open(os.path.join(output_folder_path, log_file_name), 'a') as log_file:
log_file.write(log_line)
print(log_line)
def save_models(self, output_folder_path, iteration_no):
self.dc_gan.generator_model.save(
os.path.join(output_folder_path, 'generator_model_{0}.h5'.format(iteration_no)))
self.dc_gan.discriminator_model.save(
os.path.join(output_folder_path, 'discriminator_model_{0}.h5'.format(iteration_no)))
self.dc_gan.concatenated_model.save(
os.path.join(output_folder_path, 'concatenated_model_{0}.h5'.format(iteration_no)))
def sample_real_images(self, output_folder_path):
filepath = os.path.join(output_folder_path, 'CIFAR10_Sample_Real_Images.png')
i = numpy.random.randint(0, self.x_train.shape[0], 16)
images = self.x_train[i, :, :, :]
plt.figure(figsize=(10, 10))
for i in range(16):
plt.subplot(4, 4, i + 1)
image = images[i, :, :, :]
image = numpy.reshape(image, [32, 32, 3])
plt.imshow(image)
plt.axis('off')
plt.tight_layout()
plt.savefig(filepath)
plt.close('all')
def generate_images(self, output_folder_path, iteration_no, noise=None):
filepath = os.path.join(output_folder_path, 'CIFAR10_Gen_Image{0}.png'.format(iteration_no))
if noise is None:
noise = numpy.random.uniform(-1, 1, size=[16, 100])
# Generator output has tanh activation whose range is [-1,1]
images = (self.dc_gan.generator_model.predict(noise) + 1) / 2
plt.figure(figsize=(10, 10))
for i in range(16):
plt.subplot(4, 4, i + 1)
image = images[i, :, :, :]
image = numpy.reshape(image, [32, 32, 3])
plt.imshow(image)
plt.axis('off')
plt.tight_layout()
plt.savefig(filepath)
plt.close('all')
def main():
cifar10_trainer = Cifar10Trainer()
cifar10_trainer.train(train_steps=10000, log_interval=10, save_interval=100)
del cifar10_trainer.dc_gan
return
if __name__ == '__main__':
start_time = time.time()
print('Program Started at {0}'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))))
main()
end_time = time.time()
print('Program Ended at {0}'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))))
print('Total Execution Time: {0}s'.format(datetime.timedelta(seconds=end_time - start_time)))
Некоторые графики приведены ниже:
Оптимизатор дискриминатора: SGD (lr = 0,0001, бета1 = 0,5)
Оптимизатор генератора: Адам (lr = 0,0001, бета1= 0,5)
Оптимизатор дискриминатора: SGD (lr = 0,0001)
Оптимизатор генератора: SGD (lr = 0,0001)
Оптимизатор дискриминатора: SGD (lr = 0,0001)
Оптимизатор генератора: SGD (lr = 0,001)
Оптимизатор дискриминатора: SGD (lr = 0,0001)
Оптимизатор генератора: SGD (lr = 0,0005)