Все изображения в партии одинаковы во время обучения GAN - PullRequest
1 голос
/ 10 июня 2019
def train(self,gray_scale_image_dataset,color_image_dataset,test_image):
    SEED= 50
    random.seed(SEED)
    generator = self.generator_model()
    discriminator = self.discriminator_model()

    gray_scale_images = gray_scale_image_dataset
    colored_images = color_image_dataset

    gen_optimizer = tf.train.AdamOptimizer(self.learning_rate,beta1=0.5)
    dis_optimizer = tf.train.AdamOptimizer(self.learning_rate,beta1=0.5)
    for eachEpoch in range(self.epochs):

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            for i in range(20):
                random.shuffle(gray_scale_image_dataset)
                random.shuffle(color_image_dataset)
            gray_scale_dataset_image = gray_scale_images[:self.batch_size]
            print(gray_scale_dataset_image.shape)
            color_dataset_image_batch = colored_images[:self.batch_size]
            #lets see which images are being trained

            self.draw_train_images(color_dataset_image_batch) 
            generated_image = generator(gray_scale_dataset_image)
            real_output = discriminator(color_dataset_image_batch)

            fake_output = discriminator(generated_image)
            print("What  Discriminator Thought about real_output = {} and fake output = {}".format(real_output[0],fake_output[0]))
            gen_loss = self.generator_loss(fake_output,generated_image,color_dataset_image_batch)
            dis_loss = self.discriminator_loss(fake_output,real_output)
            print("GEN LOSS {} and DISC = {}".format(gen_loss[0],dis_loss[0]))

        gen_gradients = gen_tape.gradient(gen_loss,generator.trainable_variables)
        disc_gradients = disc_tape.gradient(dis_loss,discriminator.trainable_variables)
        gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
        dis_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))

        print ("EPOCHS COMPLETED = {} ".format(eachEpoch))
        self.draw_images(generator,test_image)

Это моя функция поезда, которая фактически обучает сеть ГАН.Генератор и дискриминатор - это две разные сети, которые были созданы после статьи, поэтому с этим нет проблем.Моя проблема с изображениями, которые я передаю.Я проверил изображения, оба gray_scale_image_dataset,color_image_dataset, и они хорошо работают.Но когда я перехожу к функции self.draw_train_images и пытаюсь нарисовать их в matplotlib, в сетке отображается только первое изображение.Вся сетка заполнена первым изображением, и только это изображение используется для обучения данных, поэтому я получаю много ошибок.Любая помощь по этому вопросу?Где я все испортил?

...