I am trying to use a GAN to generate images of cells. It seems to work in grayscale, but in color it produces results whose colors have nothing to do with the training set. I tried adding several thousand more iterations and increasing the size of the convolutional blocks, but nothing changed. What is the reason for this? One channel looks fine (e.g. blue), but the others are off.
Here is my code: dim = 1 or 3 switches between black-and-white and color, and I change cv2.IMREAD_GRAYSCALE to cv2.COLOR_BGR2RGB when loading the data. What am I doing wrong?
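For reference, this is a minimal sketch of how I understand the color loading path is supposed to look (my assumptions: cv2.imread always returns BGR, and cv2.COLOR_BGR2RGB is a conversion code for cv2.cvtColor, not an imread flag; the load_images helper exists only for illustration). If these assumptions are wrong, that may be part of my problem.

import cv2
import glob
import numpy as np

IMG_SIZE = 128

def load_images(pattern, color=True):
    # Returns (N, IMG_SIZE, IMG_SIZE, 3) RGB or (N, IMG_SIZE, IMG_SIZE, 1) grayscale, scaled to [-1, 1]
    images = []
    for path in glob.glob(pattern):
        flag = cv2.IMREAD_COLOR if color else cv2.IMREAD_GRAYSCALE
        img = cv2.imread(path, flag)
        if img is None:  # unreadable file
            continue
        if color:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # imread gives BGR; matplotlib expects RGB
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        images.append(img)
    arr = np.asarray(images, dtype='float32')
    if not color:
        arr = arr[..., np.newaxis]  # add the channel axis for grayscale
    return (arr - 127.5) / 127.5  # normalize to [-1, 1] for a tanh generator

The full script follows: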
import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
# Load data (the commented line is the original MNIST loader)
# (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
import cv2
train_images = []
files = glob.glob("/content/train/*.png")
IMG_SIZE = 128
dim = 1
for myFile in files:
    try:
        img_array = cv2.imread(myFile, cv2.IMREAD_GRAYSCALE)  # read into an array (swap the flag here: cv2.IMREAD_GRAYSCALE / cv2.COLOR_BGR2RGB)
        # img_array = img_array[:, :, 0:1]
        new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))  # resize to normalize data size
        # new_array = new_array[:, :, 0:1]
        train_images.append(new_array)  # add this to our training data
    except Exception as e:  # in the interest of keeping the output clean...
        pass
    # except OSError as e:
    #     print("OSError, bad img most likely", e, os.path.join(path, img))
    # except Exception as e:
    #     print("general exception", e, os.path.join(path, img))
train_images = np.asarray(train_images)
# Resize and normalize
train_images = train_images.reshape(train_images.shape[0], IMG_SIZE, IMG_SIZE, dim).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
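# The [-1, 1] range pairs with the generator's tanh output activation below.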
print(np.shape(train_images))
print("ok")
plt.imshow(train_images[1, :, :, 0])
plt.show()
# Make batches of images
BUFFER_SIZE = 9
BATCH_SIZE = 3
EPOCHS = 2000
noise_dim = 100
num_examples_to_generate = 1
# We will reuse this seed over time (so it's easier to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
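# Note: shuffle only mixes BUFFER_SIZE (= 9) elements at a time, so the dataset
# is fully shuffled only if BUFFER_SIZE is at least the number of images.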
print(np.shape(train_dataset))
##################### Generator and discriminator models
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense((IMG_SIZE // 8) * (IMG_SIZE // 8) * 1024, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((IMG_SIZE // 8, IMG_SIZE // 8, 1024)))
    assert model.output_shape == (None, IMG_SIZE // 8, IMG_SIZE // 8, 1024)  # Note: None is the batch size
    # model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    # assert model.output_shape == (None, IMG_SIZE // 8, IMG_SIZE // 8, 128)
    # model.add(layers.BatchNormalization())
    # model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(512, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, IMG_SIZE // 8, IMG_SIZE // 8, 512)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, IMG_SIZE // 4, IMG_SIZE // 4, 256)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, IMG_SIZE // 2, IMG_SIZE // 2, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    # Output layer: `dim` channels (1 for grayscale, 3 for color), tanh keeps values in [-1, 1]
    model.add(layers.Conv2DTranspose(dim, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, IMG_SIZE, IMG_SIZE, dim)
    return model
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same',
                            input_shape=[IMG_SIZE, IMG_SIZE, dim]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Conv2D(512, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Conv2D(1024, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
# Helper to compute the cross-entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
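# from_logits=True matches the discriminator's final Dense(1), which returns raw logits (no sigmoid).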
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    print("images from train steps")
    print(np.shape(images))  # printed only while the function is traced, not on every step
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
    print(np.shape(dataset))
    for epoch in range(epochs):
        start = time.time()
        for image_batch in dataset:
            train_step(image_batch)
        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)
        # Save the model every 15 epochs
        # if (epoch + 1) % 15 == 0:
        #     checkpoint.save(file_prefix=checkpoint_prefix)
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)
from PIL import Image
def generate_and_save_images(model, epoch, test_input):
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        # plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5)  # , cmap='gray'
        plt.axis('off')
        # print(np.shape(predictions[0, :, :, :]))
        # if epoch == EPOCHS:
        #     my_imm = np.asarray(predictions[0, :, :, :])
        #     im = Image.fromarray(np.uint8(my_imm * 255)).convert('RGB')
        #     im = Image.fromarray((my_imm * 255).astype(np.uint8)).convert('RGB')
        #     im.save('image_at_epoch_{:04d}.jpg'.format(epoch))
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    # plt.show()
###################################################################
generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0])
discriminator = make_discriminator_model()
print("to the discriminator")
print(np.shape(generated_image))
decision = discriminator(generated_image)
print(decision)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
# Checkpoint setup
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
train(train_dataset, EPOCHS)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
# Display a single image using the epoch number
def display_image(epoch_no):
    return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)