Я пытаюсь сохранить модель Keras Tensorflow cGAN (pix2pix) как SavedModel
и делать с ней прогнозы (без TF-обслуживания на этом этапе, только путем загрузки модели непосредственно в коде).
За исключением нескольких небольших рефакторингов, реализация в основном идентична на веб-сайте Tensorflow , который сохраняет модель как checkpoints
.
Основная функция выглядит следующим образом: this:
train_dataset = (
tf.data.Dataset.list_files("./dataset/train/*.jpg")
.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE)
)
test_dataset = (
tf.data.Dataset.list_files("./dataset/test/*.jpg")
.map(load_image_test)
.batch(BATCH_SIZE)
)
generator = generator()
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator = discriminator()
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
checkpoint = tf.train.Checkpoint(
generator=generator,
generator_optimizer=generator_optimizer,
discriminator=discriminator,
discriminator_optimizer=discriminator_optimizer,
)
train(
generator=generator,
discriminator=discriminator,
train_dataset=train_dataset,
test_dataset=test_dataset,
epochs=EPOCHS,
checkpoint=checkpoint,
checkpoint_prefix=os.path.join("./checkpoints", "ckpt"),
)
Обучение работает, как ожидается, и следующие прогнозы работают:
for input_image, target_image in test_dataset.take(1):
prediction = generator(input_image, training=False)
print(prediction[0])
tf.Tensor(
[[[-5.31874537e-01 -4.05731827e-01 -5.50406814e-01]
[-6.34294271e-01 -4.89944369e-01 -5.99278808e-01]
[-6.67676032e-01 -5.50003827e-01 -6.57217383e-01]
...
Однако при попытке сохранить модель загрузите ее и сделать прогноз, я получаю только тензор, заполненный nan
:
tf.saved_model.save(generator, "./model")
loaded = tf.saved_model.load("./model")
infer = loaded.signatures["serving_default"]
for input_image, target_image in test_dataset.take(1):
prediction = infer(input_image)[generator.output_names[0]]
print(prediction[0])
tf.Tensor(
[[[nan nan nan]
[nan nan nan]
[nan nan nan]
...
Наконец, я также попытался сохранить checkpoint
(tf.saved_model.save(checkpoint, "./model")
), но это отсутствует подпись для вывода, и я не знаю, как ее определить.
Я все еще очень плохо знаком с TF, любая помощь будет принята с благодарностью. Приветствия