Keras mnist gan не генерирует правильного вывода. Понятия не имею, что я делаю не так - PullRequest
0 голосов
/ 06 октября 2019

Ниже приведен мой код для генерации цифр mnist, и я включил потери и выходные данные для пояснения. Приветствуется любая помощь

. Выходные потери не очень помогают понять, почему выходные изображения являются помехами. Прошу прощения за нулевую документацию.

from keras.activations import sigmoid
from keras.optimizers import Adam
from keras.models import Sequential
from keras.datasets import mnist
from keras.utils import normalize
from matplotlib import pyplot as plt

import numpy as np

from sklearn.utils import shuffle

def discriminator():
    model = Sequential()
    model.add(InputLayer(input_shape=(28,28,1)))
    model.add(Conv2D(64, 3, padding='same'))
    model.add(LeakyReLU(alpha=0.1))
    model.add(Dropout(0.3))
    model.add(Conv2D(32, 3, strides=(2,2), padding='same'))
    model.add(LeakyReLU(0.1))
    model.add(Dropout(0.3))
    model.add(Conv2D(25, 3, strides=(2,2), padding='same'))
    model.add(LeakyReLU(0.1))
    model.add(Dropout(0.3))
    model.add(Conv2D(25, 3, strides=(2,2), padding='same'))
    model.add(LeakyReLU(0.1))
    model.add(Flatten())

    model.add(Dense(1, activation='sigmoid'))
    model.compile(optimizer=Adam(0.0002), loss='binary_crossentropy')

    model.summary()

    return model


def generator():
    model = Sequential()
    model.add(Dense(150*7*7, input_shape=(100,) ))
    model.add(LeakyReLU(0.1))

    model.add(Reshape((7,7,150)))

    model.add(Conv2DTranspose(256, (2,2), strides=(2,2), padding='same'))
    model.add(LeakyReLU(0.1))

    model.add(Conv2DTranspose(128, (3,3), strides=2, padding='same'))
    model.add(LeakyReLU(0.1))

    model.add(Conv2D(1, kernel_size=(7,7), padding='same', activation='sigmoid'))
    model.add(LeakyReLU(0.2))
    model.compile(optimizer=Adam(0.0002), loss='binary_crossentropy')

    model.summary()

    return model


def gan(generator, discriminator):
    model = Sequential()
    model.add(generator)
    discriminator.trainable = False
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002), metrics=['accuracy'])

    model.add(discriminator)

    model.compile(optimizer=Adam(0.0002), loss='binary_crossentropy')
    model.summary()

    return model


def noise(nsamples, dim=100):
    x = np.random.randn(nsamples,dim)
    return x


def fake_images(generator, nsampels,):
    x = noise(nsampels)
    output_x = generator.predict(x)
    output_y = np.zeros((nsampels,1))
    return output_x, output_y


def real_images(nsampels,dataset):
    x = np.random.randint(0, dataset.shape[0], nsampels)
    x = dataset[x].reshape(-1,28,28,1)
    y = np.ones((nsampels, 1))        
    return x, y


def save_plot(epoch, gen_imgs):
    r, c = 5, 5
    fig, axs = plt.subplots(r, c)
    cnt = -1
    for i in range(r):
        for j in range(c):
            cnt += 1
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')

    fig.savefig("image/mnist_%d.png" % epoch)
    plt.close()

def dis_performence(discriminator, generator, epoch, dataset):
    x, y = real_images(100, dataset)
    _,acc_real = discriminator.evaluate(x,y,verbose=0)
    X, Y = fake_images(generator, 100)
    _,acc_fake = discriminator.evaluate(X,Y,verbose=0)

    print(f'\n Accuracy for real: {acc_real}%,   accuracy for fake: {acc_fake}% \n')
    generator.save(f'generator {epoch}.h5')
    save_plot(epoch,X)


def train(generator, discriminator,n_epoch, gan, dataset, data_per_batch=250):

    no_of_batches_per_epoch = dataset.shape[0]//data_per_batch
    for epoch in range(n_epoch):
        for i in range(no_of_batches_per_epoch):
            x_real, y_real = real_images(data_per_batch//2, dataset)
            x_fake, y_fake = fake_images(generator, data_per_batch//2)

            X_dis, Y_dis = np.vstack((x_real,x_fake)), np.vstack((y_real,y_fake))

            loss_dis, _ = discriminator.train_on_batch(X_dis, Y_dis)

            X_gan = noise(data_per_batch)
            Y_gan = np.ones((data_per_batch,1))

            gan_loss = gan.train_on_batch(X_gan,Y_gan)

        print(f'epoch: {epoch} / {n_epoch}, g_loss: {gan_loss}, d_loss: {loss_dis}')
        if epoch%10 == 0:
            dis_performence(discriminator, generator, epoch, dataset)


dis = discriminator()
gen = generator()
gan = gan(generator=gen, discriminator=dis)


(train_data,_),(_,_) = mnist.load_data()
train_data = normalize(train_data)
train(gen, dis, 100, gan, train_data,)

экранная крышка всех потерь модели гана и модели дискриминатора.

выходпосле 20 эпохи. Вывод выглядит так, как будто это все одно и то же изображение, но я не знаю, почему это происходит (это просто похоже на шум)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...