Я хочу использовать TPUGANEstimator для обучения модифицированной версии pix2pix GAN на TPU с дополнительными условными изображениями (исходную публикацию (не-TPU) можно найти здесь: https://phillipi.github.io/pix2pix/). Для обучения генератора мне нужно скормитьэто input_image и condition_image, и для дискриминатора я передаю target_image и condition_image. Моя проблема в том, как Оценщик находит правильный ввод в модель для словаря, который я создаю в input_fn. Вот псевдокод для моего кода входной функции:
def input_fn(mode, params):
is_train = mode == tf.estimator.ModeKeys.TRAIN
# Yields a triplet of input_image, condition_image, target_image
data_gen = training_generator if is_train else test_generator
dataset = tf.data.Dataset.from_generator(data_gen,
({'input': tf.float32, 'condition': tf.float32, 'target': tf.float32}),
output_shapes=({'input': tf.TensorShape(shape), 'condition': tf.TensorShape(shape),
'target': tf.TensorShape(shape)}))
if is_train:
dataset = dataset.shuffle(buffer_size=int(config["shuffle_buffer_ratio"] * config["batch_size"]))
dataset = dataset.prefetch(config["max_buffer"])
dataset = dataset.batch(config["batch_size"])
# return dataset
iterator = dataset.make_one_shot_iterator()
next_item = iterator.get_next()
return next_item
Первая проблема состоит в том, чтобы определить генератор для использования двух входов из словаря:
def generator_fn(input_dict, mode='TRAIN', scope='Generator'):
input = input_dict[‘input’]
condition = input_dict[‘condition’]
with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
x = generator_network(x)
return x
И вторая проблема заключается в том, как переключать последовательность между поддельными и действительными изображениями.для дискриминатора (возможно, с использованием параметра joint_train):
def discriminator_fn(image, input_dict, scope='Discriminator'):
x = discriminator_network(x)
return x
Я определил показатель TPUGANEstimator в соответствии с документацией, но я не нашел хорошего учебника / примера о том, как создать правильную последовательность.
gan_estimator = tfgan.estimator.TPUGANEstimator(
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
generator_loss_fn=tfgan.losses.minimax_generator_loss,
discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss,
generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5),
discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5),
joint_train=False,
gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
model_dir=config['model_dir'],
params=params,
use_tpu=False,
train_batch_size=2,
eval_batch_size=2,
config=t_config)
while cur_step < number_of_steps:
print("Running gan estimator: {}".format(cur_step))
gan_estimator.train(train_input_fn, steps=cur_step)
Я высоко ценю вашу помощь вdvance.