Вызов `model.predict ()` из внешней присоединенной функции - PullRequest
2 голосов
/ 12 июля 2020

Используя this в качестве ссылки, я придумал следующий код:

import tensorflow as tf
from tensorflow.keras.applications.densenet import DenseNet121
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess_input
import inspect, cv2
import numpy as np

@tf.function(input_signature=[tf.TensorSpec([None, None, 3],dtype=tf.uint8)])
def _preprocess(image_array):
    im_arr = tf.image.resize(image_array, (resize_height, resize_width))
    im_arr = densenet_preprocess_input(im_arr)
    input_batch = tf.expand_dims(im_arr, axis=0)
    return input_batch

training_model = DenseNet121(include_top=True, weights='imagenet')

#Assign resize dimensions
resize_height = tf.constant(480, dtype=tf.int64)
resize_width = tf.constant(640, dtype=tf.int64)

#Attach function to Model
training_model.preprocess = _preprocess

#Attach resize dimensions to Model
training_model.resize_height = resize_height
training_model.resize_width = resize_width

training_model.save("saved_model", overwrite=True)

, который в основном присоединяет метод под названием preprocess к tf.keras.Model определенному для DenseNet121 .

Чтобы позже я мог использовать его следующим образом, чтобы сделать прогноз:

pred_model = tf.keras.models.load_model('saved_model')

#download image
image_path = tf.keras.utils.get_file("cat.jpg", "https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg")
#load and convert the image to tf.uint8 numpy array
image_array = np.array(tf.keras.preprocessing.image.load_img(path=image_path))

#call the custom function bound to the model
preprocessed_image = pred_model.preprocess(image_array)

result = pred_model.predict(preprocessed_image)
print(np.argmax(result, axis=-1), np.amax(result, axis=-1))

Мой вопрос:

Как я могу вызовите метод предсказать модель из функции предварительной обработки. Так что

preprocessed_image = pred_model.preprocess(image_array)
result = pred_model.predict(preprocessed_image)

может стать

result = pred_model.preprocess_predict(image_array)
...