I have a conditional GAN (CGAN) model in Keras:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, BatchNormalization, Dropout, multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as K
import numpy as np
import sys, os
import warnings
warnings.filterwarnings('ignore')
if not os.path.exists('images'): os.makedirs('images')
class GAN(object):
    def __init__(self, width=28, height=28, channels=1, latent_dim=100, lr=0.0002):
        self.WIDTH = int(width) # width of input images
        self.HEIGHT = int(height) # height of input images
        self.CHANNELS = int(channels) # n color channels in images
        self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
        self.LATENT_DIM = latent_dim # length of vector used to model latent space (= noise)
        self.N_CLASSES = 10 # total number of possible classes in the data
        self.OPTIMIZER = Adam(lr, 0.5)

        # generator
        self.G = self.generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

        # discriminator
        self.D = self.discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])
        self.D.trainable = False # prevent stacked D from training; https://github.com/eriklindernoren/Keras-GAN/issues/73

        # stacked generator + discriminator
        self.stacked_G_D = self.stacked_G_D()
        self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)
    def generator(self):
        noise = Input((self.LATENT_DIM,), name='generator_noise') # allows g to create different outputs
        label = Input((1,), name='generator_conditional', dtype='int32') # allows g to create samples from one class
        # embed label in size of latent dimension
        h = Embedding(self.N_CLASSES, self.LATENT_DIM, input_length=1)(label)
        label_embedding = Flatten()(h)
        # unified model
        h = multiply([noise, label_embedding])
        h = Dense(256)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = BatchNormalization(momentum=0.8)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = BatchNormalization(momentum=0.8)(h)
        h = Dense(1024)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = BatchNormalization(momentum=0.8)(h)
        h = Dense(np.prod(self.SHAPE), activation='tanh')(h)
        o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)
        model = Model(inputs=[noise, label], outputs=[o])
        model.summary()
        return model
    def discriminator(self):
        image = Input((self.SHAPE))
        label = Input((1,), dtype='int32')
        # embed the label in the shape of an image (flattened)
        h = Embedding(self.N_CLASSES, np.prod(self.SHAPE), input_length=1)(label)
        label_embedding = Flatten()(h)
        # parse out the image
        img = Flatten()(image)
        # unified model
        h = multiply([img, label_embedding])
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        o = Dense(1, activation='sigmoid')(h)
        model = Model(inputs=[image, label], outputs=[o])
        model.summary()
        return model
    def stacked_G_D(self):
        noise = Input((self.LATENT_DIM,)) # noise input
        label = Input((1,)) # conditional input
        img = self.G([noise, label])
        valid = self.D([img, label])
        model = Model(inputs=[noise, label], outputs=[valid])
        model.summary()
        return model
    def train(self, X_train, Y_train, epochs=20000, batch=32, save_interval=100):
        for i in range(epochs):
            # train the discriminator
            idx = np.random.randint(0, X_train.shape[0], batch)
            imgs, labels = X_train[idx], Y_train[idx]
            noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
            fake_imgs = self.G.predict([noise, labels])
            d_loss_real = self.D.train_on_batch([imgs, labels], np.ones((batch, 1)))
            d_loss_fake = self.D.train_on_batch([fake_imgs, labels], np.zeros((batch, 1)))
            d_loss = np.add(d_loss_real, d_loss_fake) * 0.5
            # train the generator
            sample_labels = np.random.randint(0, 10, batch).reshape(batch, 1)
            g_loss = self.stacked_G_D.train_on_batch([noise, sample_labels], np.ones((batch, 1)))
            if i % save_interval == 0:
                print('epoch: {0} - disc loss: {1}, disc accuracy: {2}, gen loss: {3}'.format(i, d_loss[0], d_loss[1]*100, g_loss))
                filename = 'mnist_{0}-{1}-{2}.png'.format(i, d_loss[0], g_loss)
                self.plot_images(save_to_disk=True, filename=filename)
    def plot_images(self, save_to_disk=False, n_images=10, filename=None, rows=2, size_scalar=4, class_arr=None):
        if not filename: filename = 'mnist.png'
        noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
        classes = class_arr if class_arr is not None else np.arange(0, n_images) % self.N_CLASSES
        images = self.G.predict([noise, classes])
        cols = int(np.ceil(n_images / rows)) # n_cols in grid (int so it can be passed to add_subplot)
        fig = plt.figure(figsize=(cols*size_scalar, rows*size_scalar))
        for i in range(n_images):
            ax = fig.add_subplot(rows, cols, i+1)
            image = np.reshape(images[i], [28, 28])
            plt.imshow(image)
        fig.subplots_adjust(hspace=0, wspace=0)
        if save_to_disk:
            fig.savefig(os.path.join('images', filename))
            plt.close('all')
        else:
            fig.show()
(X_train, Y_train), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 # rescale {-1 to 1}
X_train = np.expand_dims(X_train, axis=3)
gan = GAN()
gan.train(X_train, Y_train)
My goal is to periodically freeze the discriminator so that it temporarily stops learning (this is for some experimental work). However, I can't find a way to actually update the .trainable attribute of gan.D once the model has been compiled. I've tried flipping the attribute by hand at intervals, but no matter what I do the discriminator keeps learning.
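Concretely, my attempt looked roughly like this (a simplified sketch inside the training loop; the 500-epoch interval is just for illustration):

if i % 500 == 0:
    # toggle the flag mid-training; the discriminator keeps updating regardless
    gan.D.trainable = not gan.D.trainable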
Is it actually possible to update a model's trainable attribute after that model has been compiled? If so, I'd appreciate a simple example of how to do it!