def train(self,gray_scale_image_dataset,color_image_dataset,test_image):
SEED= 50
random.seed(SEED)
generator = self.generator_model()
discriminator = self.discriminator_model()
gray_scale_images = gray_scale_image_dataset
colored_images = color_image_dataset
gen_optimizer = tf.train.AdamOptimizer(self.learning_rate,beta1=0.5)
dis_optimizer = tf.train.AdamOptimizer(self.learning_rate,beta1=0.5)
for eachEpoch in range(self.epochs):
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
for i in range(20):
random.shuffle(gray_scale_image_dataset)
random.shuffle(color_image_dataset)
gray_scale_dataset_image = gray_scale_images[:self.batch_size]
print(gray_scale_dataset_image.shape)
color_dataset_image_batch = colored_images[:self.batch_size]
#lets see which images are being trained
self.draw_train_images(color_dataset_image_batch)
generated_image = generator(gray_scale_dataset_image)
real_output = discriminator(color_dataset_image_batch)
fake_output = discriminator(generated_image)
print("What Discriminator Thought about real_output = {} and fake output = {}".format(real_output[0],fake_output[0]))
gen_loss = self.generator_loss(fake_output,generated_image,color_dataset_image_batch)
dis_loss = self.discriminator_loss(fake_output,real_output)
print("GEN LOSS {} and DISC = {}".format(gen_loss[0],dis_loss[0]))
gen_gradients = gen_tape.gradient(gen_loss,generator.trainable_variables)
disc_gradients = disc_tape.gradient(dis_loss,discriminator.trainable_variables)
gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
dis_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
print ("EPOCHS COMPLETED = {} ".format(eachEpoch))
self.draw_images(generator,test_image)
Это моя функция поезда, которая фактически обучает сеть ГАН.Генератор и дискриминатор - это две разные сети, которые были созданы после статьи, поэтому с этим нет проблем.Моя проблема с изображениями, которые я передаю.Я проверил изображения, оба gray_scale_image_dataset,color_image_dataset
, и они хорошо работают.Но когда я перехожу к функции self.draw_train_images
и пытаюсь нарисовать их в matplotlib, в сетке отображается только первое изображение.Вся сетка заполнена первым изображением, и только это изображение используется для обучения данных, поэтому я получаю много ошибок.Любая помощь по этому вопросу?Где я все испортил?