Не удается загрузить SavedModel - PullRequest
0 голосов
/ 10 февраля 2020

Я тренирую некоторые модели, используя tf.keras, и хочу сохранить обученную модель. Есть два рекомендуемых способа сделать это: файл tf SavedModel и keras .h5. Однако с SavedModel.

все очень запутано. Вот несколько коротких сценариев, воспроизводящих проблему:

import numpy as np
import tensorflow as tf

def my_network():
    backbone_model = tf.keras.applications.InceptionResNetV2(
        include_top=False, input_shape=(224, 224, 3), weights="imagenet", pooling="avg"
    )

    inputs = tf.keras.layers.Input(shape=(224, 224, 3), name="images")
    backbone_features = backbone_model(inputs)
    pre_embeddings = tf.keras.layers.Dense(
        512,
        activation=None,
        name="pre_embeddings",
        kernel_regularizer=tf.keras.regularizers.l2(),
    )(backbone_features)

    embeddings = tf.keras.layers.Lambda(
        lambda x: tf.math.l2_normalize(x, 1, 1e-10), name="embeddings"
    )(pre_embeddings)

    probs = tf.keras.layers.Dense(
        1000,
        activation="softmax",
        name="predictions",
        kernel_regularizer=tf.keras.regularizers.l2(),
        bias_regularizer=tf.keras.regularizers.l2(),
    )(pre_embeddings)

    return tf.keras.Model(inputs, [embeddings, probs], name="my_network")


img_arr = np.random.rand(1, 224, 224, 3)

resnet_model = my_network()
emb_1, _ = resnet_model.predict(img_arr)
resnet_model.save("./resnet_model.h5")

new_model = tf.keras.models.load_model('./resnet_model.h5')
emb_2, _ = new_model.predict(img_arr)

np.testing.assert_array_almost_equal(emb_1, emb_2)

Приведенный выше сценарий будет работать без ошибок. Однако это не удается, когда я пытался сохранить в формате SavedModel (удаляя .h5 из пути модели). Модель успешно сохранена, но выдает ошибку при загрузке, и появляется сообщение об ошибке:

NotImplementedError: When subclassing the `Model` class, you should implement a `call` method.

Я сбит с толку, потому что я не использовал подклассовую модель. Как показано в сценарии, моя сеть построена только с функциональным API.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...