Training a GAN network is slow on Google Cloud TPU - PullRequest
0 votes
/ March 29, 2020

I am trying to train a GAN network on a Google Cloud TPU. I modified the code at https://github.com/deepak112/Keras-SRGAN/blob/master/simplified/train.py, mainly porting it to use the Dataset API and my own dataset. When I run my code it trains very slowly, with low CPU and memory utilization on the TPU - it seems I have introduced some serious bottlenecks. Does anyone see what I am doing wrong?

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import keras
from keras.applications.vgg19 import VGG19
import keras.backend as K
from keras.layers import add, Dense, Input, Lambda
from keras.layers.advanced_activations import LeakyReLU, PReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers.core import Activation, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model
import numpy as np
from tensorflow.keras.optimizers import SGD, Adam
import tensorflow as tf
import tensorflow_datasets as tfds


image_lr_shape = (64,64,3)
image_hr_shape = (256,256,3)
image_lr_dir = "gs://******/Dataset/Small/*.png"
image_hr_dir = "gs://******/Dataset/Large/*.png"
model_dir = "gs://******/Models"
epochs = 1000
batch_size = 128

def vgg_loss(y_true, y_pred):
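    # Perceptual (VGG) loss: MSE between the block5_conv4 feature maps of the
    # target and the prediction. Note: this builds a fresh ImageNet VGG19 on
    # every single loss call.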
    vgg19 = VGG19(include_top=False, weights="imagenet", input_shape=image_hr_shape)
    vgg19.trainable = False
    for l in vgg19.layers:
        l.trainable = False
    loss_model = Model(inputs=vgg19.input,outputs=vgg19.get_layer("block5_conv4").output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))

def residual_block(model,kernel_size,filters,strides):
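    # SRGAN residual block: Conv-BN-PReLU-Conv-BN with an identity skip connection.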
    gen = model
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = PReLU(alpha_initializer="zeros",alpha_regularizer=None,alpha_constraint=None,shared_axes=[1,2])(model)
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = add([gen, model])
    return model

def up_sampling_block(model,kernel_size,filters,strides):
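    # Doubles the spatial resolution: Conv2D, then nearest-neighbour UpSampling2D and LeakyReLU.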
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = UpSampling2D(size=2)(model)
    model = LeakyReLU(alpha=0.2)(model)
    return model

def discriminator_block(model, filters, kernel_size, strides):
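    # Standard discriminator block: Conv-BN-LeakyReLU (stride 2 halves the resolution).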
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = LeakyReLU(alpha=0.2)(model)
    return model

def generator_network():
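    # SRGAN generator: a 9x9 entry conv, 16 residual blocks with a global skip
    # connection, then two 2x up-sampling blocks for 4x super-resolution
    # (64x64 -> 256x256).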
    gen_input = Input(shape=image_lr_shape)
    model = Conv2D(filters=64,kernel_size=9,strides=1,padding="same")(gen_input)
    model = PReLU(alpha_initializer="zeros",alpha_regularizer=None,alpha_constraint=None,shared_axes=[1,2])(model)
    gen_model = model
    for index in range(16):
        model = residual_block(model,3,64,1)
    model = Conv2D(filters=64,kernel_size=3,strides=1,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = add([gen_model,model])
    for index in range(2):
        model = up_sampling_block(model,3,256,1)
    model = Conv2D(filters=3,kernel_size=9,strides=1,padding="same")(model)
    model = Activation("tanh")(model)
    generator_model = Model(inputs=gen_input,outputs=model)
    return generator_model

def discriminator_network():
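    # SRGAN discriminator: a stack of conv blocks of growing depth, then two
    # dense layers ending in a sigmoid real/fake score.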
    dis_input = Input(shape=image_hr_shape)
    model = Conv2D(filters=64,kernel_size=3,strides=1,padding="same")(dis_input)
    model = LeakyReLU(alpha=0.2)(model)
    model = discriminator_block(model,64,3,2)
    model = discriminator_block(model,128,3,1)
    model = discriminator_block(model,128,3,2)
    model = discriminator_block(model,256,3,1)
    model = discriminator_block(model,256,3,2)
    model = discriminator_block(model,512,3,1)
    model = discriminator_block(model,512,3,2)
    model = Flatten()(model)
    model = Dense(1024)(model)
    model = LeakyReLU(alpha=0.2)(model)
    model = Dense(1)(model)
    model = Activation("sigmoid")(model)
    discriminator_model = Model(inputs=dis_input,outputs=model)
    return discriminator_model

def get_gan_network(discriminator,generator,optimizer):
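    # Combined GAN: the generator's output image is scored by the frozen
    # discriminator; trained with the VGG loss on the image plus a weighted
    # binary cross-entropy on the discriminator score.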
    discriminator.trainable=False
    gan_input = Input(shape=image_lr_shape)
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input,outputs=[x,gan_output])
    gan.compile(loss=[vgg_loss,"binary_crossentropy"],loss_weights=[1.,1e-3], optimizer=optimizer)
    return gan

def load_image(path):
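    # Reads one PNG (here from GCS) and converts it to float32 in [0, 1].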
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img,channels=3)
    img = tf.image.convert_image_dtype(img,tf.float32)
    return img

#============================================WARNING: Dataset size hardcoded!
print("Connecting to TPU...")
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="******",zone="******",project="******")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
    print("Loading dataset...")
    image_hr = tf.data.Dataset.list_files(image_hr_dir,shuffle=False).map(load_image).shuffle(6777,seed=12345679,reshuffle_each_iteration=True).batch(batch_size).prefetch(4)
    image_lr = tf.data.Dataset.list_files(image_lr_dir,shuffle=False).map(load_image).shuffle(6777,seed=12345679,reshuffle_each_iteration=True).batch(batch_size).prefetch(4)
    batch_count = int(6777/batch_size)
    print("Compiling neural network---")
    generator = generator_network()
    discriminator = discriminator_network()
    adam = Adam(learning_rate=1e-4,beta_1=0.9,beta_2=0.999,epsilon=1e-8)
    generator.compile(loss=vgg_loss,optimizer=adam)
    discriminator.compile(loss="binary_crossentropy",optimizer=adam)
    gan = get_gan_network(discriminator,generator,adam)
    print("Starting training...")
    for e in range(1,epochs+1):
        print("-"*15,"Epoch %d" % e, "-"*15)
        for b in range(batch_count):
            print("-"*10,"Batch %d" % b,"-"*10)
            batch_hr = np.stack(tfds.as_numpy(image_hr.take(1)))
            batch_lr = np.stack(tfds.as_numpy(image_lr.take(1)))
            batch_sr = generator.predict(batch_lr)
            real_y = tf.random.uniform(shape=(batch_size,1),minval=0.8,maxval=1)
            fake_y = tf.random.uniform(shape=(batch_size,1),minval=0,maxval=0.2)
            discriminator.trainable = True
            d_loss_real = discriminator.train_on_batch(batch_hr,real_y)
            d_loss_fake = discriminator.train_on_batch(batch_sr,fake_y)
            batch_hr = np.stack(tfds.as_numpy(image_hr.take(1)))
            batch_lr = np.stack(tfds.as_numpy(image_lr.take(1)))
            gan_y = tf.random.uniform(shape=(batch_size,1),minval=0.8,maxval=1)
            discriminator.trainable = False
            loss_gan = gan.train_on_batch(batch_lr,[batch_hr,gan_y])
        print("Loss d_real,Loss d_fake, Loss network")
        print(d_loss_real,d_loss_fake,loss_gan)
        os.makedirs(os.path.join(model_dir,str(e)))
        generator.save(os.path.join(model_dir,str(e),"gen.h5"))
        discriminator.save(os.path.join(model_dir,str(e),"dis.h5"))
        gan.save(os.path.join(model_dir,str(e),"gan.h5"))

UPDATE: I added more debug prints; it gets stuck on batch_hr - there is no error, it just hangs there.
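My current suspicion is that each `image_hr.take(1)` / `tfds.as_numpy(...)` call builds a brand-new iterator over the whole pipeline, so the 6777-image shuffle buffer has to be refilled from GCS before a single batch comes out, and the HR and LR streams can also drift out of sync. A sketch of what I am planning to try instead (untested; it assumes list_files(shuffle=False) lists the small and large files in the same sorted order, so zip keeps the pairs aligned):

paired_paths = tf.data.Dataset.zip((
    tf.data.Dataset.list_files(image_lr_dir, shuffle=False),
    tf.data.Dataset.list_files(image_hr_dir, shuffle=False)))
paired = (paired_paths
          .shuffle(6777, seed=12345679, reshuffle_each_iteration=True)  # shuffle cheap path pairs, not decoded images
          .map(lambda lr, hr: (load_image(lr), load_image(hr)),
               num_parallel_calls=tf.data.experimental.AUTOTUNE)
          .batch(batch_size, drop_remainder=True)
          .prefetch(tf.data.experimental.AUTOTUNE))

for e in range(1, epochs + 1):
    batches = iter(paired)                    # one iterator per epoch...
    for b in range(batch_count):
        batch_lr, batch_hr = next(batches)    # ...advanced once per batch
        batch_sr = generator.predict_on_batch(batch_lr)
        # ...rest of the training step as before...

(I also notice that I am building the models with the standalone keras package rather than tensorflow.keras; as far as I understand, TPUStrategy only distributes tensorflow.keras models, which could explain the low TPU utilization even once the input pipeline is fixed.)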

...