Встраивание функции предварительной обработки в модель tf.keras для обслуживания - PullRequest
0 голосов
/ 12 апреля 2020

Я пытаюсь встроить простую функцию предварительной обработки изображений в уже обученную модель tf.keras. Это полезная функция, поскольку она может помочь нам сократить объем стандартного кода, необходимого при использовании любой модели для служебных целей. Благодаря этой возможности вы получаете гораздо больше гибкости и модульности для вашей модели.

Итак, после обучения моей модели я сначала определяю функцию предварительной обработки, например:

def preprocess_image_cv2(image_path):
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.resize(img, (28, 28)).astype("float32")
    img = img / 255
    img = np.expand_dims(img, 0)
    img = tf.convert_to_tensor(img)
    return img

Тогда я используя его для создания другого класса модели вместе с обученной моделью -

# Define the model for predcition purpose
class ExportModel(tf.keras.Model):
    def __init__(self, preproc_func, model):
        super().__init__(self)
        self.preproc_func = preproc_func
        self.model = model

    @tf.function
    def my_serve(self, image_path):
        print("Inside")
        preprocessed_image = self.preproc_func(image_path) # Preprocessing
        probabilities = self.model(preprocessed_image, training=False) # Model prediction
        class_id = tf.argmax(probabilities[0], axis=-1) # Postprocessing
        return {"class_index": class_id}

Затем я могу выполнить вывод для образца изображения с такой настройкой:

# Now initialize a dummy model and fill its parameters with that of
# the model we trained
restored_model = get_training_model()
restored_model.set_weights(apparel_model.get_weights())

# Now use this model, preprocessing function, and the same image
# for checking if everything is working
serving_model = ExportModel(preprocess_image_cv2, restored_model)
class_index = serving_model.my_serve("sample_image.png")
CLASSES[class_index["class_index"].numpy()] # prints Dress

Но я не могу экспортировать эту модель для сервировки. Я делаю следующее для экспорта -

# Make sure we are *not* letting the model to train
tf.keras.backend.set_learning_phase(0)

# Serialize model
export_path = "model_preprocessing_func"
tf.saved_model.save(serving_model, export_path, signatures={"serving_default": serving_model.my_serve})

Это дает -

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-97-9e2616e04da9> in <module>()
      1 export_path = "model_preprocessing_func"
----> 2 tf.saved_model.save(serving_model, export_path, signatures={"serving_default": serving_model.my_serve})

2 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
    949 
    950   _, exported_graph, object_saver, asset_info = _build_meta_graph(
--> 951       obj, export_dir, signatures, options, meta_graph_def)
    952   saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
    953 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, export_dir, signatures, options, meta_graph_def)
   1009 
   1010   signatures, wrapped_functions = (
-> 1011       signature_serialization.canonicalize_signatures(signatures))
   1012   signature_serialization.validate_saveable_view(checkpoint_graph_view)
   1013   signature_map = signature_serialization.create_signature_map(signatures)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_serialization.py in canonicalize_signatures(signatures)
    110           ("Expected a TensorFlow function to generate a signature for, but "
    111            "got {}. Only `tf.functions` with an input signature or "
--> 112            "concrete functions can be used as a signature.").format(function))
    113 
    114     wrapped_functions[original_function] = signature_function = (

ValueError: Expected a TensorFlow function to generate a signature for, but got <tensorflow.python.eager.def_function.Function object at 0x7fd5b646ea58>. Only `tf.functions` with an input signature or concrete functions can be used as a signature.

Я могу интерпретировать последнюю часть ошибки, но я не могу понять, какие шаги следует Я беру, чтобы решить это. Можно воспроизвести проблему с этим Colab Notebook . Помощь приветствуется.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...