Подготовка данных для обучения ГАН на ТПУ - PullRequest
0 голосов
/ 24 сентября 2019

Я хочу использовать TPUGANEstimator для обучения модифицированной версии pix2pix GAN на TPU с дополнительными условными изображениями (исходную публикацию (не-TPU) можно найти здесь: https://phillipi.github.io/pix2pix/). Для обучения генератора мне нужно скормитьэто input_image и condition_image, и для дискриминатора я передаю target_image и condition_image. Моя проблема в том, как Оценщик находит правильный ввод в модель для словаря, который я создаю в input_fn. Вот псевдокод для моего кода входной функции:

def input_fn(mode, params):
    is_train = mode == tf.estimator.ModeKeys.TRAIN

    # Yields a triplet of input_image, condition_image, target_image
    data_gen = training_generator if is_train else test_generator 

    dataset = tf.data.Dataset.from_generator(data_gen,
        ({'input': tf.float32, 'condition': tf.float32, 'target': tf.float32}),
        output_shapes=({'input': tf.TensorShape(shape), 'condition': tf.TensorShape(shape),
                        'target': tf.TensorShape(shape)}))
    if is_train:
        dataset = dataset.shuffle(buffer_size=int(config["shuffle_buffer_ratio"] * config["batch_size"]))

    dataset = dataset.prefetch(config["max_buffer"])
    dataset = dataset.batch(config["batch_size"])
    # return dataset

    iterator = dataset.make_one_shot_iterator()
    next_item = iterator.get_next()
    return next_item

Первая проблема состоит в том, чтобы определить генератор для использования двух входов из словаря:

def generator_fn(input_dict, mode='TRAIN', scope='Generator'):
    input = input_dict[‘input’]
    condition = input_dict[‘condition’]
    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        x = generator_network(x)
    return x

И вторая проблема заключается в том, как переключать последовательность между поддельными и действительными изображениями.для дискриминатора (возможно, с использованием параметра joint_train):

def discriminator_fn(image, input_dict, scope='Discriminator'):
    x = discriminator_network(x)
return x

Я определил показатель TPUGANEstimator в соответствии с документацией, но я не нашел хорошего учебника / примера о том, как создать правильную последовательность.

gan_estimator = tfgan.estimator.TPUGANEstimator(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=tfgan.losses.minimax_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss,
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5),
    joint_train=False,
    gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
    model_dir=config['model_dir'],
    params=params,
    use_tpu=False,
    train_batch_size=2,
    eval_batch_size=2,
    config=t_config)

while cur_step < number_of_steps:
    print("Running gan estimator: {}".format(cur_step))
    gan_estimator.train(train_input_fn, steps=cur_step)

Я высоко ценю вашу помощь вdvance.

...