Сохранение модели Tensorflow cGAN (pix2pix) Keras в виде `SavedModel` - PullRequest
0 голосов
/ 01 мая 2020

Я пытаюсь сохранить модель Keras Tensorflow cGAN (pix2pix) как SavedModel и делать с ней прогнозы (без TF-обслуживания на этом этапе, только путем загрузки модели непосредственно в коде).

За исключением нескольких небольших рефакторингов, реализация в основном идентична на веб-сайте Tensorflow , который сохраняет модель как checkpoints.

Основная функция выглядит следующим образом: this:

train_dataset = (
    tf.data.Dataset.list_files("./dataset/train/*.jpg")
    .map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

test_dataset = (
    tf.data.Dataset.list_files("./dataset/test/*.jpg")
    .map(load_image_test)
    .batch(BATCH_SIZE)
)

generator = generator()
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator = discriminator()
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

checkpoint = tf.train.Checkpoint(
    generator=generator,
    generator_optimizer=generator_optimizer,
    discriminator=discriminator,
    discriminator_optimizer=discriminator_optimizer,
)

train(
    generator=generator,
    discriminator=discriminator,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    epochs=EPOCHS,
    checkpoint=checkpoint,
    checkpoint_prefix=os.path.join("./checkpoints", "ckpt"),
)

Обучение работает, как ожидается, и следующие прогнозы работают:

for input_image, target_image in test_dataset.take(1):
    prediction = generator(input_image, training=False)
    print(prediction[0])
tf.Tensor(
[[[-5.31874537e-01 -4.05731827e-01 -5.50406814e-01]
  [-6.34294271e-01 -4.89944369e-01 -5.99278808e-01]
  [-6.67676032e-01 -5.50003827e-01 -6.57217383e-01]
  ...

Однако при попытке сохранить модель загрузите ее и сделать прогноз, я получаю только тензор, заполненный nan:

tf.saved_model.save(generator, "./model")

loaded = tf.saved_model.load("./model")
infer = loaded.signatures["serving_default"]

for input_image, target_image in test_dataset.take(1):
    prediction = infer(input_image)[generator.output_names[0]]
    print(prediction[0])
tf.Tensor(
[[[nan nan nan]
  [nan nan nan]
  [nan nan nan]
  ...

Наконец, я также попытался сохранить checkpoint (tf.saved_model.save(checkpoint, "./model")), но это отсутствует подпись для вывода, и я не знаю, как ее определить.

Я все еще очень плохо знаком с TF, любая помощь будет принята с благодарностью. Приветствия

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...