Точность дискриминатора и частота дураков генератора близки к 1,0 со слоями нормализации партии - PullRequest
0 голосов
/ 04 июля 2019

Я создаю генеративную состязательную сеть в керасе / тензорном потоке, чтобы генерировать изображения собак.В первый раз, когда я собрал сеть, все заработало как положено.Скорость дурака генератора и точность дискриминатора были обратно соотнесены, как и должно быть.Чтобы улучшить сеть, я добавил несколько BatchNormalization слоев как для генератора, так и для дискриминатора.Мне удалось заморозить слои батнорм на дискриминаторе при обучении генератора, установив layer._per_input_updates = {}. Я подтвердил, что слои были заморожены во время обучения генератора, сравнивая матрицы фактического веса до и после каждой итерации генератора.Однако, когда я распечатываю скорость дурака генератора и точность дискриминатора, они оба сходятся к 1,0.Это происходит только тогда, когда я использую слои нормализации партии.Мои вопросы: а) Значит ли это, что сеть работает неправильно?и б) Если да, то как мне заставить его тренироваться правильно, все еще используя слои BatchNormalization?

Вот мой код (Пожалуйста, игнорируйте многие операторы печати и комментарии, большинство из них предназначались для отладки в какой-то момент):

import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Activation, BatchNormalization,Conv2D, MaxPooling2D, Dense, Dropout, LeakyReLU,Reshape,Input
from tensorflow.keras.optimizers import Adam
import pickle

allimages = []
for i in range(1,21):
  allimages+=pickle.load(open('ImageFile' + str(i) + '.0','rb'))

def genDogs(num):
  allinds = np.random.randint(0,len(allimages),num)
  finaldata = [allimages[i] for i in allinds]
  return np.array(finaldata)

def getWeights(m):
    return [list(w.reshape((-1))) for w in m.get_weights()]
def compareWeights(l1,l2):
    assert len(l1)==len(l2)
    for w1,w2 in zip(l1,l2):
        if not np.array_equal(w1,w2):
            print('blaaaahhhh! No')
            return False
    print('ALL WEIGHTS SAME')
    return True

#run_opts = tf.RunOptions(report_tensor_allocations_upon_oom = True)
#generator
generator = Sequential()
#generator input can be a 100-len vector and output a 4096*3=12288 length vector
generator.add(Dense(512,input_shape=[100],activation='tanh'))
generator.add(Dense(2048))
generator.add(BatchNormalization())
generator.add(LeakyReLU())
generator.add(Dense(4096))
generator.add(BatchNormalization())
generator.add(Activation('tanh'))
generator.add(Dense(12288,activation='sigmoid'))
generator.add(Reshape([64,64,3]))
generator.compile(loss='binary_crossentropy',optimizer=Adam())#,options=run_opts)


#discriminator
discriminator = Sequential()
#input shape [batch,64,64,3]
discriminator.add(Conv2D(64,(4,4),padding='same',input_shape=[64,64,3])) #outputs [None,64,64,64]
discriminator.add(MaxPooling2D()) #new dims [None, 32,32,64]
discriminator.add(BatchNormalization())
discriminator.add(Dropout(.1))
discriminator.add(LeakyReLU())
discriminator.add(Conv2D(128,(2,2),padding='same')) #[None,32,32,128]
discriminator.add(MaxPooling2D()) #[None,16,16,128]
discriminator.add(BatchNormalization())
discriminator.add(Dropout(.1))
discriminator.add(LeakyReLU())
discriminator.add(Conv2D(32,(4,4),padding='valid',activation='tanh')) #shape [None,13,13,32]
discriminator.add(BatchNormalization())
discriminator.add(Reshape([5408]))#5408 size
discriminator.add(Dense(256,activation=LeakyReLU()))
discriminator.add(Dense(1,activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy',optimizer=Adam(),metrics=['accuracy'])#,options=run_opts)

#combined
discriminator.trainable = False
for layer in discriminator.layers:
    layer.trainable = False
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer._per_input_updates = {}
gan_input = Input([100])
outs = discriminator(generator(gan_input))
combined = Model(inputs=gan_input,outputs=outs)
combined.compile(loss='binary_crossentropy',metrics=['accuracy'],optimizer=Adam())#,options=run_opts)
print('###########################################loaded models!')
def showIm():
  noise1 = np.random.random([1,100])
  generated = generator.predict_on_batch(noise1)
  plt.imshow(generated[0])

def iteration(halfbatch):

  #train the discriminator once first
  discriminator.trainable=True
  noise1 = np.random.random([halfbatch,100])
  generated = generator.predict_on_batch(noise1)
  dogs = genDogs(halfbatch)
  together = np.concatenate((generated,dogs),axis=0)
  y = np.array([0 for _ in range(halfbatch)] + [1 for _ in range(halfbatch)])
  #print(discriminator.predict(together))
  #print('#######################################fitting discriminator!')
  outs = discriminator.train_on_batch(together,y)
  print('discriminator loss: ' + str(outs[0]) + ', discriminator accuracy: '+ str(outs[1]))

  preW = getWeights(discriminator)
  #train the combined network
  for _ in range(1):
    noise2 = np.random.random([halfbatch,100])
    labels = np.ones([halfbatch])
    #print(combined.predict(noise2))
    #print('#########################################fitting generator!')
    outs2 = combined.train_on_batch(noise2,labels)
    postW = getWeights(discriminator)
    compareWeights(preW,postW)
    print('generator loss: ' + str(outs2[0]) + ', generator fool rate: ' + str(outs2[1]))

for i in range(100):
    iteration(64)
generator.save('generator0.h5')

Вот пример выходных данных, которые я запускаю:

generator loss: 1.7851067, generator fool rate: 0.8125
WARNING:tensorflow:Discrepancy between trainable weights and collected trainable weights, did
 you set `model.trainable` without calling `model.compile` after ?
discriminator loss: 1.0960464e-07, discriminator accuracy: 1.0
ALL WEIGHTS SAME
generator loss: 1.630374, generator fool rate: 0.84375
WARNING:tensorflow:Discrepancy between trainable weights and collected trainable weights, did
 you set `model.trainable` without calling `model.compile` after ?
discriminator loss: 1.0960464e-07, discriminator accuracy: 1.0
ALL WEIGHTS SAME

Как видите, веса дискриминатора не тренируются во время итераций генератора.Тем не менее, скорость дурака генератора и точность дискриминатора сходятся к 1,0.Этого не произошло, когда я избавился от BatchNormalization слоев.

...