I am trying to train a GAN on a Google Cloud TPU. I adapted the code from https://github.com/deepak112/Keras-SRGAN/blob/master/simplified/train.py, mainly porting it to the tf.data Dataset API and to my own dataset. When I run it, training is very slow, with low CPU and memory utilization on the TPU; it looks like I have made some bad choices and introduced bottlenecks. Does anyone see what I am doing wrong?
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import keras
from keras.applications.vgg19 import VGG19
import keras.backend as K
from keras.layers import add, Dense, Input, Lambda
from keras.layers.advanced_activations import LeakyReLU, PReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers.core import Activation, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model
import numpy as np
from tensorflow.keras.optimizers import SGD, Adam
import tensorflow as tf
import tensorflow_datasets as tfds
image_lr_shape = (64,64,3)
image_hr_shape = (256,256,3)
image_lr_dir = "gs://******/Dataset/Small/*.png"
image_hr_dir = "gs://******/Dataset/Large/*.png"
model_dir = "gs://******/Models"
epochs = 1000
batch_size = 128
def vgg_loss(y_true, y_pred):
    vgg19 = VGG19(include_top=False, weights="imagenet", input_shape=image_hr_shape)
    vgg19.trainable = False
    for l in vgg19.layers:
        l.trainable = False
    loss_model = Model(inputs=vgg19.input,outputs=vgg19.get_layer("block5_conv4").output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
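
# Note (my addition, not in the original code): vgg_loss() above builds a
# brand-new VGG19, including loading the imagenet weights, every time Keras
# invokes the loss. A minimal sketch of building the feature extractor once and
# closing over it instead; make_vgg_loss is a name I made up, and it is shown
# for reference only, not wired into the compiles below:
def make_vgg_loss():
    vgg19 = VGG19(include_top=False, weights="imagenet", input_shape=image_hr_shape)
    vgg19.trainable = False
    loss_model = Model(inputs=vgg19.input, outputs=vgg19.get_layer("block5_conv4").output)
    loss_model.trainable = False
    def loss(y_true, y_pred):
        # same perceptual loss, but the VGG graph is reused across calls
        return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
    return loss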
def residual_block(model,kernel_size,filters,strides):
    gen = model
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = PReLU(alpha_initializer="zeros",alpha_regularizer=None,alpha_constraint=None,shared_axes=[1,2])(model)
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = add([gen, model])
    return model

def up_sampling_block(model,kernel_size,filters,strides):
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = UpSampling2D(size=2)(model)
    model = LeakyReLU(alpha=0.2)(model)
    return model

def discriminator_block(model, filters, kernel_size, strides):
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = LeakyReLU(alpha=0.2)(model)
    return model

def generator_network():
    gen_input = Input(shape=image_lr_shape)
    model = Conv2D(filters=64,kernel_size=9,strides=1,padding="same")(gen_input)
    model = PReLU(alpha_initializer="zeros",alpha_regularizer=None,alpha_constraint=None,shared_axes=[1,2])(model)
    gen_model = model
    for index in range(16):
        model = residual_block(model,3,64,1)
    model = Conv2D(filters=64,kernel_size=3,strides=1,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = add([gen_model,model])
    for index in range(2):
        model = up_sampling_block(model,3,256,1)
    model = Conv2D(filters=3,kernel_size=9,strides=1,padding="same")(model)
    model = Activation("tanh")(model)
    generator_model = Model(inputs=gen_input,outputs=model)
    return generator_model

def discriminator_network():
    dis_input = Input(shape=image_hr_shape)
    model = Conv2D(filters=64,kernel_size=3,strides=1,padding="same")(dis_input)
    model = LeakyReLU(alpha=0.2)(model)
    model = discriminator_block(model,64,3,2)
    model = discriminator_block(model,128,3,1)
    model = discriminator_block(model,128,3,2)
    model = discriminator_block(model,256,3,1)
    model = discriminator_block(model,256,3,2)
    model = discriminator_block(model,512,3,1)
    model = discriminator_block(model,512,3,2)
    model = Flatten()(model)
    model = Dense(1024)(model)
    model = LeakyReLU(alpha=0.2)(model)
    model = Dense(1)(model)
    model = Activation("sigmoid")(model)
    discriminator_model = Model(inputs=dis_input,outputs=model)
    return discriminator_model

def get_gan_network(discriminator,generator,optimizer):
    discriminator.trainable = False
    gan_input = Input(shape=image_lr_shape)
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input,outputs=[x,gan_output])
    gan.compile(loss=[vgg_loss,"binary_crossentropy"],loss_weights=[1.,1e-3], optimizer=optimizer)
    return gan
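
# Note: discriminator.trainable = False inside get_gan_network() only freezes
# the discriminator within the combined gan model, because Keras captures the
# flag when a model is compiled. The standalone discriminator is compiled below
# before get_gan_network() runs, so its own weights stay trainable; this is the
# usual Keras GAN pattern, just easy to misread.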
def load_image(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img,channels=3)
    img = tf.image.convert_image_dtype(img,tf.float32)
    return img
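
# Note: convert_image_dtype() scales the PNGs to [0, 1], while the generator's
# final tanh outputs [-1, 1], so real and generated images end up in different
# ranges. Unrelated to speed, but worth double-checking.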
#============================================WARNING: Dataset size hardcoded!
print("Connecting to TPU...")
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="******",zone="******",project="******")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
    print("Loading dataset...")
    image_hr = (tf.data.Dataset.list_files(image_hr_dir,shuffle=False)
                .map(load_image)
                .shuffle(6777,seed=12345679,reshuffle_each_iteration=True)
                .batch(batch_size)
                .prefetch(4))
    image_lr = (tf.data.Dataset.list_files(image_lr_dir,shuffle=False)
                .map(load_image)
                .shuffle(6777,seed=12345679,reshuffle_each_iteration=True)
                .batch(batch_size)
                .prefetch(4))
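
    # Sketch (my assumption, not the original repo's code): zipping the two file
    # lists and shuffling *before* decoding keeps the LR/HR pairs aligned by
    # construction instead of relying on two identically seeded shuffles, and it
    # avoids filling a 6777-element shuffle buffer with decoded images.
    # paired_ds is a name I made up and is not used below:
    paired_ds = (tf.data.Dataset.zip((
            tf.data.Dataset.list_files(image_lr_dir,shuffle=False),
            tf.data.Dataset.list_files(image_hr_dir,shuffle=False)))
        .shuffle(6777,seed=12345679,reshuffle_each_iteration=True)
        .map(lambda lr,hr: (load_image(lr),load_image(hr)),
             num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE))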
    batch_count = int(6777/batch_size)
    print("Compiling neural network---")
    generator = generator_network()
    discriminator = discriminator_network()
    adam = Adam(learning_rate=1e-4,beta_1=0.9,beta_2=0.999,epsilon=1e-8)
    generator.compile(loss=vgg_loss,optimizer=adam)
    discriminator.compile(loss="binary_crossentropy",optimizer=adam)
    gan = get_gan_network(discriminator,generator,adam)
    print("Starting training...")
    for e in range(1,epochs+1):
        print("-"*15,"Epoch %d" % e, "-"*15)
        for b in range(batch_count):
            print("-"*10,"Batch %d" % b,"-"*10)
            batch_hr = np.stack(tfds.as_numpy(image_hr.take(1)))
            batch_lr = np.stack(tfds.as_numpy(image_lr.take(1)))
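            # NOTE: each .take(1) here and below builds and runs a fresh copy of
            # the whole input pipeline: it re-lists the files and must decode
            # enough PNGs to fill the 6777-image shuffle buffer before a single
            # batch comes out, and this happens four times per inner iteration.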
            batch_sr = generator.predict(batch_lr)
            real_y = tf.random.uniform(shape=(batch_size,1),minval=0.8,maxval=1)
            fake_y = tf.random.uniform(shape=(batch_size,1),minval=0,maxval=0.2)
            discriminator.trainable = True
            d_loss_real = discriminator.train_on_batch(batch_hr,real_y)
            d_loss_fake = discriminator.train_on_batch(batch_sr,fake_y)
            batch_hr = np.stack(tfds.as_numpy(image_hr.take(1)))
            batch_lr = np.stack(tfds.as_numpy(image_lr.take(1)))
            gan_y = tf.random.uniform(shape=(batch_size,1),minval=0.8,maxval=1)
            discriminator.trainable = False
            loss_gan = gan.train_on_batch(batch_lr,[batch_hr,gan_y])
            print("Loss d_real, Loss d_fake, Loss network")
            print(d_loss_real,d_loss_fake,loss_gan)
        os.makedirs(os.path.join(model_dir,str(e)))
        generator.save(os.path.join(model_dir,str(e),"gen.h5"))
        discriminator.save(os.path.join(model_dir,str(e),"dis.h5"))
        gan.save(os.path.join(model_dir,str(e),"gan.h5"))
UPDATE: I added more debug prints; it gets stuck on fetching batch_hr. There is no error, it just hangs there.
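
For reference, this is roughly how I think the batches should be pulled instead of calling .take(1) inside the loop, namely one repeating iterator per dataset. This is an untested sketch; hr_batches and lr_batches are names I made up, and it still relies on the matching shuffle seeds to keep the LR/HR pairs aligned, as my current code already does:

hr_batches = iter(tfds.as_numpy(image_hr.repeat()))
lr_batches = iter(tfds.as_numpy(image_lr.repeat()))
for e in range(1, epochs + 1):
    for b in range(batch_count):
        # one (batch_size,256,256,3) and one (batch_size,64,64,3) array per
        # step, with no pipeline rebuild; the GAN step can pull a second pair
        # with next() in the same way
        batch_hr = next(hr_batches)
        batch_lr = next(lr_batches)
        # ...discriminator and GAN updates as before...

With .repeat() the iterators never run out, and reshuffle_each_iteration should still reshuffle between passes. Is that the right direction?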