Какова лучшая условная порождающая модель? И какие-либо советы для реализации? - PullRequest
0 голосов
/ 23 февраля 2020

Я работаю над генеративными моделями, чтобы сгенерировать несколько образцов для моего исследования. По какой-то причине мне пришлось распутать скрытое пространство, поэтому я попытался применить идеи infoGAN и условного GAN к различным архитектурам.

Я думал, что это легко, но каким-то образом модель дала мне плохие результаты. Даже ту же архитектуру, которая дает хорошие результаты без условий, при каждом добавлении условий становится трудно обучаться.

Итак, мне интересно, какова лучшая условная архитектура, и кто-то может дать советы, как помочь с моими моделями? Ниже приведены некоторые подробности:

Изображения:

Сгенерированные образцы с оригинальным DCGAN (без условий, хорошо)

Сгенерированные образцы с условным условием DCGAN (ошибка)

Сгенерированные образцы с условным WGAN-GP (плохое качество)

Коды:

Библиотеки

from functools import partial
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Sequential, Input
from tensorflow.keras.layers import Dense, Conv2D, Conv2DTranspose, UpSampling2D, AveragePooling2D
from tensorflow.keras.layers import Reshape, Flatten, Concatenate, BatchNormalization, Activation
from tensorflow.keras.layers import ReLU, LeakyReLU
import numpy as np
from scipy.linalg import sqrtm
from utils import *
import time
import os

Модель

class cWGAN4Rank2(object):
    def __init__(self, z_dim=1, batch_size=64, gen_lr=1e-4, disc_lr=6e-5, lam=10., n_critic=5):
        self.lam = lam
        self.n_critic = n_critic
        # Dataset
        self.batch_size = batch_size
        dataset, num_samples = load_rank2_dataset(False)
        self.dataset = dataset.shuffle(num_samples).batch(batch_size)
        elem_spec = dataset.element_spec
        # Data shapes and dimensions
        self.image_size = elem_spec[0].shape[0]
        self.num_channel = elem_spec[0].shape[2]
        self.image_shape = (self.image_size, self.image_size, self.num_channel)
        self.z_dim = z_dim
        # Models
        self.D = self.build_discriminator(32, 5)
        self.G = self.build_generator(34, 5)
        self.I = keras.models.load_model('/data/inception4rank2.h5')
        c = 0
        for w in self.D.get_weights():
            c += w.size
        print(c)
        c = 0
        for w in self.G.get_weights():
            c += w.size
        print(c)
        # Optimizers
        self.D_opt = keras.optimizers.Adam(disc_lr, .5, .9)
        self.G_opt = keras.optimizers.Adam(gen_lr, .5, .9)
        # Test noise
        self.num_ex = 64
        z = tf.random.uniform(shape=(self.num_ex, self.z_dim), minval=-1, maxval=1)
        c = tf.reshape(tf.constant(np.linspace(0,1,self.num_ex, dtype='float32')), (self.num_ex,1))
        self.seed = tf.concat([z,c], axis=1)
        # Records
        self.epoch = 0
        self.step = 0
        self.steps = []
        self.steps_per_epoch = []
        self.D_loss_per_step = []
        self.D_loss_per_epoch = []
        self.G_loss_per_step = []
        self.G_loss_per_epoch = []
        self.FID_per_step = []
        self.FID_per_epoch = []

    def build_discriminator(self, filters, kernel_size):
        c = Input(shape=(1,))
        t = Dense(32*32)(c)
        t = Reshape((32,32,1))(t)
        t = Activation(tf.nn.leaky_relu)(BatchNormalization()(t))

        x = Input(shape=self.image_shape, name='image_input') # 32x32x1
        y = Conv2D(filters, kernel_size, padding='same')(x)
        y = Activation(tf.nn.leaky_relu)(BatchNormalization()(y))
        y = Concatenate()([y,t])
        y = Conv2D(2*filters, kernel_size, strides=(2,2), padding='same')(y) # 16x16xn
        y = Activation(tf.nn.leaky_relu)(BatchNormalization()(y))
        y = Conv2D(4*filters, kernel_size, strides=(2,2), padding='same')(y) # 8x8xn
        y = Activation(tf.nn.leaky_relu)(BatchNormalization()(y))
        z = Dense(1)(Flatten()(y))
        return Model([x,c], z)

    def build_generator(self, filters, kernel_size):
        x = Input(shape=(self.z_dim+1,))
        y = Reshape((8,8,filters))(Dense(8*8*filters)(x)) # 8x8xn
        y = Activation(tf.nn.relu)(BatchNormalization()(y))
        y = Conv2DTranspose(2*filters, kernel_size, strides=(2,2), padding='same')(y) # 16x16xn
        y = Activation(tf.nn.relu)(BatchNormalization()(y))
        y = Conv2DTranspose(4*filters, kernel_size, strides=(2,2), padding='same')(y) # 32x32xn
        y = Activation(tf.nn.relu)(BatchNormalization()(y))
        z = Conv2D(1, kernel_size, padding='same', activation='tanh')(y)
        return Model(x, z)

    def train(self, num_epochs):
        for e in range(num_epochs):
            self.epoch += 1
            t_epoch = time.time()
            for batch in self.dataset:
                self.step += 1
                for _ in range(self.n_critic):
                    D_loss = self.train_discriminator(batch)
                self.D_loss_per_step.append(D_loss.numpy())
                G_loss = self.train_generator()
                self.G_loss_per_step.append(G_loss.numpy())

                self.steps.append(self.step)
            self.steps_per_epoch.append(self.step)
            self.D_loss_per_epoch.append(D_loss)
            self.G_loss_per_epoch.append(G_loss)
            FID = self.fid_eval(batch)
            self.FID_per_epoch.append(FID)

            if (e+1) % 10 == 0:
                display.clear_output(True)
                self.visualization()

            print('[{: 3d}/{: 3d}]epoch, D_loss= {:.4e}, G_loss= {:.4e}, time= {:.2f} sec'.format(e+1, num_epochs, D_loss, G_loss, time.time()-t_epoch))

    @tf.function
    def train_discriminator(self, batch):
        X_real = 2 * batch[0] - 1
        c_real = tf.reduce_mean(batch[0], axis=(1,2))
        z = tf.random.uniform(shape=(self.batch_size, self.z_dim), minval=-1, maxval=1)
        c_fake = tf.random.uniform(shape=(self.batch_size,1))
        h = tf.concat([z, c_fake], axis=1)
        with tf.GradientTape() as tape:
            X_fake = self.G(h, training=True)
            y_fake = self.D((X_fake,c_fake), training=True)
            y_real = self.D((X_real,c_real), training=True)
            real_loss = tf.reduce_mean(y_real)
            fake_loss = tf.reduce_mean(y_fake)
            D_loss = fake_loss - real_loss
            gp = self.gradient_penalty(partial(self.D, training=True), X_real, X_fake, c_real, c_fake)
            loss = D_loss + self.lam * gp
        grad = tape.gradient(loss, self.D.trainable_variables)
        self.D_opt.apply_gradients(zip(grad, self.D.trainable_variables))
        return D_loss

    @tf.function
    def train_generator(self):
        z = tf.random.uniform(shape=(self.batch_size, self.z_dim), minval=-1, maxval=1)
        c_fake = tf.random.uniform(shape=(self.batch_size,1))
        h = tf.concat([z, c_fake], axis=1)
        with tf.GradientTape() as tape:
            X_fake = self.G(h, training=True)
            y_fake = self.D((X_fake,c_fake), training=True)
            fake_loss = tf.reduce_mean(y_fake)
            #vol_loss = .5 * tf.reduce_mean(tf.square(tf.reduce_mean((X_fake+1)/2, axis=(1,2)) - c_fake))
            loss = -fake_loss# + vol_loss
        grad = tape.gradient(loss, self.G.trainable_variables)
        self.G_opt.apply_gradients(zip(grad, self.G.trainable_variables))
        return loss

    @tf.function
    def gradient_penalty(self, f, X_real, X_fake, c_real, c_fake):
        alpha = tf.random.uniform((self.batch_size, 1, 1, 1))
        beta = tf.reshape(alpha, (-1,1))
        inter = (alpha * X_real + (1 - alpha) * X_fake, beta * c_real + (1 - beta) * c_fake)
        with tf.GradientTape() as tape:
            tape.watch(inter)
            pred = f(inter)
        grad = tape.gradient(pred, inter)
        slopes = tf.sqrt(tf.reduce_sum(tf.square(grad[0]), axis=(1,2,3)) + tf.reduce_sum(tf.square(grad[1]), axis=1))
        return tf.reduce_mean((slopes - 1.)**2)

    def fid_eval(self,batch):
        X_real = 2 * batch[0] - 1
        z = tf.random.uniform(shape=(self.batch_size, self.z_dim), minval=-1, maxval=1)
        c_fake = tf.random.uniform(shape=(self.batch_size,1))
        h = tf.concat([z, c_fake], axis=1)
        X_fake = (self.G(h, training=False) + 1) / 2
        a_real = self.I(X_real, training=False)
        a_fake = self.I(X_fake, training=False)
        mu_real, sig_real = tf.reduce_mean(a_real, axis=0), np.cov(a_real, rowvar=False)
        mu_fake, sig_fake = tf.reduce_mean(a_fake, axis=0), np.cov(a_fake, rowvar=False)
        ssdiff = np.sum((mu_real - mu_fake)**2)
        covmean = sqrtm(sig_real.dot(sig_fake))
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        return ssdiff + np.trace(sig_real + sig_fake - 2.0 * covmean)

    def visualization(self):
        fig, ax = plt.subplots(1,3, figsize=(20,5))
        X_test = self.G(self.seed, training=False)
        nrow = int(np.floor(np.sqrt(self.num_ex)))
        ncol = int(np.ceil(self.num_ex / nrow))
        X_tile = np.zeros(shape=((self.image_size+2)*nrow,(self.image_size+2)*ncol))
        for row in range(nrow):
            for col in range(ncol):
                ind = ncol * row + col
                irow, icol = np.meshgrid(np.arange(self.image_size+2) + row*(self.image_size+2),
                                         np.arange(self.image_size+2) + col*(self.image_size+2))

                try:
                    X_tile[irow,icol] = tf.pad(X_test[ind,:,:,0],[[1,1],[1,1]])
                except:
                    None
        ax[0].imshow(1-X_tile, cmap='gray')
        ax[1].plot(self.steps, self.D_loss_per_step, label='discriminator')
        ax[1].plot(self.steps, self.G_loss_per_step, label='generator')
        ax[1].legend()
        ax[2].plot(self.steps_per_epoch, self.FID_per_epoch, label='FID')
        ax[2].legend()
        plt.show()
...