GAN с нормой партии действует очень странно, и дискриминатор, и генератор получают нулевые потери - PullRequest
0 голосов
/ 10 октября 2019

Я тренирую модель DCGAN с tenorflow.keras, и я добавил слои BatchNormalization в генератор и дискриминатор. Я обучаю gan следующими шагами: 1. обучаем дискриминатор реальными изображениями и изображениями из генератора (используя generator.predict) 2. обучаем состязательную сеть (скомпилированную с помощьюcriminator.trainable = False)

Затем я обнаружил, что послеЗа несколько раундов тренировочные потери, возвращаемые train_on_batch () как генератора, так и дискриминатора, обнуляются. Но когда я использую test_on_batch (), потери все еще огромны для генератора. И сгенерированные изображения все беспорядочные.

Сначала я подумал, что это потому, что на шаге 2., упомянутом выше при обучении состязательной сети, вход дискриминатора, содержащий только поддельные изображения, заставляет слои пакетной нормализации получитьразличное распределение как шаг 1. когда подавались как поддельные, так и реальные изображения. Но даже если я удалил все слои пакетной нормализации в дискриминаторе, та же проблема все еще существует. Только после удаления всех слоев нормализации партии проблема исчезнет. Также я узнал, что существование слоев Dropout не имеет значения. Мне интересно, почему нормализация партии может вызвать такую ​​проблему, даже если в генераторе подается шум с таким же распределением.

# Model definition
class DCGAN_128:
    def __init__(self, hidden_dim):
        generator = M.Sequential()
        generator.add(L.Dense(128 * 8 * 8, input_shape=[hidden_dim]))
        generator.add(L.Reshape([8, 8, 128]))
        generator.add(L.UpSampling2D())  # [8, 8, 128]
        generator.add(L.Conv2D(128, kernel_size=3, padding="same"))  # [16, 16, 128]
        generator.add(L.LayerNormalization())  # 4
        generator.add(L.ReLU())
        generator.add(L.UpSampling2D())  # [32, 32, 128]
        generator.add(L.Conv2D(64, kernel_size=5, padding="same"))   # [32, 32, 64]
        generator.add(L.LayerNormalization())  # 8
        generator.add(L.ReLU())
        generator.add(L.UpSampling2D())  # [64, 64, 128]
        generator.add(L.Conv2D(32, kernel_size=7, padding="same"))   # [64, 64, 32]
        generator.add(L.LayerNormalization())  # 12
        generator.add(L.ReLU())
        generator.add(L.UpSampling2D())  # [128, 128, 32]
        generator.add(L.Conv2D(3, kernel_size=3, padding="same", activation=A.sigmoid))   # [128, 128, 3]

        discriminator = M.Sequential()
        discriminator.add(L.Conv2D(32, kernel_size=5, strides=2, padding="same", input_shape=[128, 128, 3]))
        discriminator.add(L.LeakyReLU())
        # discriminator.add(L.Dropout(0.25))  # [64, 64, 32]
        discriminator.add(L.Conv2D(64, kernel_size=3, strides=2, padding="same"))
        # discriminator.add(L.BatchNormalization(epsilon=1e-5))  # 4
        discriminator.add(L.LeakyReLU())
        # discriminator.add(L.Dropout(0.25))  # [32, 32, 64]
        discriminator.add(L.Conv2D(128, kernel_size=3, strides=2, padding="same"))
        discriminator.add(L.LayerNormalization())   # 8
        discriminator.add(L.LeakyReLU())    # [16, 16, 128]
        discriminator.add(L.Dropout(0.25))
        discriminator.add(L.Conv2D(256, kernel_size=3, strides=2, padding="same"))
        discriminator.add(L.LayerNormalization())   # 12
        discriminator.add(L.LeakyReLU())    # [8, 8, 256]
        discriminator.add(L.Dropout(0.25))
        discriminator.add(L.Conv2D(512, kernel_size=3, strides=2, padding="same"))
        discriminator.add(L.LeakyReLU())    # [4, 4, 512]
        discriminator.add(L.Flatten())
        discriminator.add(L.Dense(1, activation=A.sigmoid))
        self.model_gen = generator
        self.model_dis = discriminator

        self.adv_input = L.Input([hidden_dim])
        self.adv_output = discriminator(generator(self.adv_input))
        self.model_adversarial = M.Model(self.adv_input, self.adv_output)




# Training
dcgan = hidden_dim = 100
DCGAN_128(hidden_dim)
data_loader = AnimeFacesLoader([128, 128])
batch_size = 32
n_rounds = 40000
dis_model = dcgan.model_dis
gen_model = dcgan.model_gen
adv_model = dcgan.model_adversarial
gen_model.summary()
adv_model.summary()


dis_model.compile(Opt.Adam(0.0002), Lo.binary_crossentropy)
dis_model.trainable = False
adv_model.compile(Opt.Adam(0.0002), Lo.binary_crossentropy)

layer_outputs = [layer.output for layer in dis_model.layers]
visual_model = tf.keras.Model(dis_model.input, layer_outputs)



for rounds in range(n_rounds):
    # Get output images
    if rounds % 100 == 0 and rounds > 0:
        noise = np.random.uniform(-1, 1, [16, hidden_dim])
        tiled_images = np.zeros([4*128, 4*128, 3]).astype(np.uint8)
        generated_imgs = gen_model.predict(noise)
        generated_imgs *= 256
        generated_imgs = generated_imgs.astype(np.uint8)
        for i in range(16):
            tiled_images[int(i / 4)*128: int(i / 4)*128 + 128,
                         int(i % 4)*128: int(i % 4)*128 + 128, :] = generated_imgs[i, :, :, :]
        Image.fromarray(tiled_images).save("Output/DCGAN/" + "rounds_{0}.jpg".format(rounds))


    '''
        layer_visualization = visual_model.predict(generated_imgs[:1])
        for i in range(len(layer_visualization)):
            plt.imshow(layer_visualization[i][0, :, :, 0])
            plt.show()
    '''

    # train discriminator on real & fake images
    real_imgs = data_loader.get_batch(batch_size)
    real_ys = np.ones([batch_size, 1])
    noise = np.random.uniform(-1, 1, [batch_size, hidden_dim])
    fake_ys = np.zeros([batch_size, 1])
    fake_imgs = gen_model.predict(noise)
    imgs = np.concatenate([real_imgs, fake_imgs], axis=0)
    ys = np.concatenate([real_ys, fake_ys], axis=0)


    loss_dis = dis_model.train_on_batch(imgs, ys)
    print("Round {}, Loss dis:{:.4f}".format(rounds, loss_dis))
    loss_dis_test = dis_model.test_on_batch(imgs, ys)
    print(loss_dis_test)

    noise = np.random.uniform(-1, 1, [batch_size, hidden_dim])
    fake_ys = np.ones([batch_size, 1])

    loss_gen = adv_model.train_on_batch(noise, fake_ys)
    print("Round {}, Loss gen:{:.4f}".format(rounds, loss_gen))
    loss_gen_test = adv_model.test_on_batch(noise, fake_ys)
    print(loss_gen_test)
...