Keras: update the `trainable` attribute after compiling a model
0 votes
/ April 15, 2019

I have a conditional GAN (CGAN) model in Keras:

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, BatchNormalization, Dropout, multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as K
import numpy as np
import sys, os
import warnings
warnings.filterwarnings('ignore')

if not os.path.exists('images'): os.makedirs('images')

class GAN(object):
  def __init__(self, width=28, height=28, channels=1, latent_dim=100, lr=0.0002):
    self.WIDTH = int(width) # width of input images
    self.HEIGHT = int(height) # height of input images
    self.CHANNELS = int(channels) # n color channels in images
    self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
    self.LATENT_DIM = latent_dim # length of vector used to model latent space (= noise)
    self.N_CLASSES = 10 # total number of possible classes in the data
    self.OPTIMIZER = Adam(lr, 0.5)

    # generator
    self.G = self.generator()
    self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    # discriminator
    self.D = self.discriminator()
    self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])
    self.D.trainable = False # prevent stacked D from training; https://github.com/eriklindernoren/Keras-GAN/issues/73

    # stacked generator + discriminator
    self.stacked_G_D = self.stacked_G_D()
    self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

  def generator(self):
    noise = Input((self.LATENT_DIM,), name='generator_noise') # allows g to create different outputs 
    label = Input((1,), name='generator_conditional', dtype='int32') # allows g to create samples from one class

    # embed label in size of latent dimension
    h = Embedding(self.N_CLASSES, self.LATENT_DIM, input_length=1)(label)
    label_embedding = Flatten()(h)

    # unified model
    h = multiply([noise, label_embedding])
    h = Dense(256)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(1024)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(np.prod(self.SHAPE), activation='tanh')(h)
    o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)

    model = Model(inputs=[noise, label], outputs=[o])
    model.summary()
    return model

  def discriminator(self):
    image = Input((self.SHAPE))
    label = Input((1,), dtype='int32')

    # embed the label in the shape of an image (flattened)
    h = Embedding(self.N_CLASSES, np.prod(self.SHAPE), input_length=1)(label)
    label_embedding = Flatten()(h)

    # parse out the image
    img = Flatten()(image)

    # unified model
    h = multiply([img, label_embedding])
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dropout(0.4)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dropout(0.4)(h)
    o = Dense(1, activation='sigmoid')(h)

    model = Model(inputs=[image, label], outputs=[o])
    model.summary()
    return model

  def stacked_G_D(self):
    noise = Input((self.LATENT_DIM,)) # noise input
    label = Input((1,)) # conditional input
    img = self.G([noise, label])
    valid = self.D([img, label])
    model = Model(inputs=[noise, label], outputs=[valid])
    model.summary()
    return model

  def train(self, X_train, Y_train, epochs=20000, batch=32, save_interval=100):
    for i in range(epochs):

      # train the discriminator
      idx = np.random.randint(0, X_train.shape[0], batch)
      imgs, labels = X_train[idx], Y_train[idx]
      noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
      fake_imgs = self.G.predict([noise, labels])
      d_loss_real = self.D.train_on_batch([imgs, labels], np.ones((batch, 1)))
      d_loss_fake = self.D.train_on_batch([fake_imgs, labels], np.zeros((batch, 1)))
      d_loss = (np.add(d_loss_real, d_loss_fake)) * 0.5

      # train the generator
      sample_labels = np.random.randint(0, 10, batch).reshape(batch, 1)
      g_loss = self.stacked_G_D.train_on_batch([noise, sample_labels], np.ones((batch, 1)))

      if i % save_interval == 0: 
        print('epoch: {0} - disc loss: {1}, disc accuracy: {2}, gen loss: {3}'.format(i, d_loss[0], d_loss[1]*100, g_loss))
        filename = 'mnist_{0}-{1}-{2}.png'.format(i, d_loss[0], g_loss)
        self.plot_images(save_to_disk=True, filename=filename)

  def plot_images(self, save_to_disk=False, n_images=10, filename=None, rows=2, size_scalar=4, class_arr=None):
    if not filename: filename = 'mnist.png'
    noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
    classes = class_arr if class_arr is not None else np.arange(0, n_images) % self.N_CLASSES
    images = self.G.predict([noise, classes])
    cols = int(np.ceil(n_images/rows)) # n_cols in grid
    fig = plt.figure(figsize=(cols*size_scalar, rows*size_scalar))    
    for i in range(n_images):
      ax = fig.add_subplot(rows, cols, i+1)
      image = np.reshape(images[i], [28, 28])  
      plt.imshow(image)
    fig.subplots_adjust(hspace=0, wspace=0)
    if save_to_disk:
      fig.savefig(os.path.join('images', filename))
      plt.close('all')
    else:
      fig.show()


(X_train, Y_train), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 # rescale {-1 to 1}
X_train = np.expand_dims(X_train, axis=3)
gan = GAN()
gan.train(X_train, Y_train)

My goal is to periodically freeze the discriminator so that it cannot learn. (This is for some experimental work.) However, I cannot find a way to actually update the .trainable attribute of gan.D after the model has been compiled. I have tried changing the attribute manually at intervals, but no matter what I do, the discriminator keeps learning.

Is it actually possible to update a model's trainable attribute after that model has been compiled? If so, I would appreciate a simple example of how to do it!
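
To make the problem concrete, here is a minimal sketch (illustrative only, reusing the GAN class above with made-up batch data) of what I mean: even though __init__ sets gan.D.trainable = False, a later train_on_batch call still updates D's weights.

gan = GAN()  # __init__ already set gan.D.trainable = False after compiling D

noise = np.random.normal(0, 1, (4, gan.LATENT_DIM))
labels = np.random.randint(0, 10, 4).reshape(4, 1)
fake_imgs = gan.G.predict([noise, labels])

w_before = gan.D.get_weights()
gan.D.train_on_batch([fake_imgs, labels], np.zeros((4, 1)))
w_after = gan.D.get_weights()

# D's weights still move, because D was compiled while it was still trainable
print(any(not np.array_equal(a, b) for a, b in zip(w_before, w_after)))  # True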

1 Answer

0 votes
/ April 15, 2019

Ah, you can update the .trainable attribute of a model after compiling it; you just need to recompile the model so the change takes effect.
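
Here is a minimal standalone sketch of the mechanism (assuming Keras 2.x as in the question; the tiny Dense model is just for illustration): flipping trainable has no effect until compile() is called again, after which training no longer touches the weights.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

m = Sequential([Dense(4, input_dim=8), Dense(1)])
m.compile(loss='mse', optimizer='adam')

x = np.random.random((16, 8))
y = np.random.random((16, 1))

m.trainable = False
m.train_on_batch(x, y)  # weights still update: m was compiled while trainable

m.compile(loss='mse', optimizer='adam')  # recompile so the flag takes effect
w_before = m.get_weights()
m.train_on_batch(x, y)  # now leaves the weights untouched
w_after = m.get_weights()
print(all(np.array_equal(a, b) for a, b in zip(w_before, w_after)))  # True

Below is your full model with this applied: train() takes a toggle_D_trainable argument that flips D.trainable every N batches and then recompiles D so the change actually takes effect.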

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, BatchNormalization, Dropout, multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as K
import numpy as np
import sys, os
import warnings
warnings.filterwarnings('ignore')

if not os.path.exists('images'): os.makedirs('images')

class GAN(object):
  def __init__(self, width=28, height=28, channels=1, latent_dim=100, lr=0.0002):
    self.WIDTH = int(width) # width of input images
    self.HEIGHT = int(height) # height of input images
    self.CHANNELS = int(channels) # n color channels in images
    self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
    self.LATENT_DIM = latent_dim # length of vector used to model latent space (= noise)
    self.N_CLASSES = 10 # total number of possible classes in the data
    self.OPTIMIZER = Adam(lr, 0.5)

    # generator
    self.G = self.generator()
    self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    # discriminator
    self.D = self.discriminator()
    self.D.trainable = False # normally this line follows the initial compilation of the D
    self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])

    # stacked generator + discriminator
    self.stacked_G_D = self.stacked_G_D()
    self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

  def generator(self):
    noise = Input((self.LATENT_DIM,), name='generator_noise') # allows g to create different outputs 
    label = Input((1,), name='generator_conditional', dtype='int32') # allows g to create samples from one class

    # embed label in size of latent dimension
    h = Embedding(self.N_CLASSES, self.LATENT_DIM, input_length=1)(label)
    label_embedding = Flatten()(h)

    # unified model
    h = multiply([noise, label_embedding])
    h = Dense(256)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(1024)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = BatchNormalization(momentum=0.8)(h)
    h = Dense(np.prod(self.SHAPE), activation='tanh')(h)
    o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)

    model = Model(inputs=[noise, label], outputs=[o])
    model.summary()
    return model

  def discriminator(self):
    image = Input((self.SHAPE))
    label = Input((1,), dtype='int32')

    # embed the label in the shape of an image (flattened)
    h = Embedding(self.N_CLASSES, np.prod(self.SHAPE), input_length=1)(label)
    label_embedding = Flatten()(h)

    # parse out the image
    img = Flatten()(image)

    # unified model
    h = multiply([img, label_embedding])
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dropout(0.4)(h)
    h = Dense(512)(h)
    h = LeakyReLU(alpha=0.2)(h)
    h = Dropout(0.4)(h)
    o = Dense(1, activation='sigmoid')(h)

    model = Model(inputs=[image, label], outputs=[o])
    model.summary()
    return model

  def stacked_G_D(self):
    noise = Input((self.LATENT_DIM,)) # noise input
    label = Input((1,)) # conditional input
    img = self.G([noise, label])
    valid = self.D([img, label])
    model = Model(inputs=[noise, label], outputs=[valid])
    model.summary()
    return model

  def train(self, X_train, Y_train, epochs=20000, batch=32, save_interval=100, toggle_D_trainable=None):
    for i in range(epochs):

      # train the discriminator
      idx = np.random.randint(0, X_train.shape[0], batch)
      imgs, labels = X_train[idx], Y_train[idx]
      noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
      fake_imgs = self.G.predict([noise, labels])
      d_loss_real = self.D.train_on_batch([imgs, labels], np.ones((batch, 1)))
      d_loss_fake = self.D.train_on_batch([fake_imgs, labels], np.zeros((batch, 1)))
      d_loss = (np.add(d_loss_real, d_loss_fake)) * 0.5

      # train the generator
      sample_labels = np.random.randint(0, 10, batch).reshape(batch, 1)
      g_loss = self.stacked_G_D.train_on_batch([noise, sample_labels], np.ones((batch, 1)))

      if i % save_interval == 0: 
        print('epoch: {0} - disc loss: {1}, disc accuracy: {2}, gen loss: {3}'.format(i, d_loss[0], d_loss[1]*100, g_loss))
        filename = 'mnist_{0}-{1}-{2}.png'.format(i, d_loss[0], g_loss)
        self.plot_images(save_to_disk=True, filename=filename)
      if i > 0 and toggle_D_trainable and i % toggle_D_trainable == 0:
        self.D.trainable = False if self.D.trainable else True
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])

  def plot_images(self, save_to_disk=False, n_images=10, filename=None, rows=2, size_scalar=4, class_arr=None):
    if not filename: filename = 'mnist.png'
    noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
    classes = class_arr if class_arr is not None else np.arange(0, n_images) % self.N_CLASSES
    images = self.G.predict([noise, classes])
    cols = int(np.ceil(n_images/rows)) # n_cols in grid
    fig = plt.figure(figsize=(cols*size_scalar, rows*size_scalar))    
    for i in range(n_images):
      ax = fig.add_subplot(rows, cols, i+1)
      image = np.reshape(images[i], [28, 28])  
      plt.imshow(image)
    fig.subplots_adjust(hspace=0, wspace=0)
    if save_to_disk:
      fig.savefig(os.path.join('images', filename))
      plt.close('all')
    else:
      fig.show()


(X_train, Y_train), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 # rescale {-1 to 1}
X_train = np.expand_dims(X_train, axis=3)
gan = GAN()
gan.train(X_train, Y_train, save_interval=100, toggle_D_trainable=1000)