Как использовать сохраненную модель тензорного потока с помощью `save_model.simple_save`? - PullRequest
0 голосов
/ 23 октября 2019

Модель была сохранена с использованием этой функции:

import tensorflow as tf
from tensorflow.python.estimator.export import export as export_helpers
import keras.backend as K

def save_for_serving(self):
        K.clear_session()  # to reset input tensor names
        model = self.build_model(serving=True)
        model.load_weights(self._train_config['model_checkpoint_path'])

        K.set_learning_phase(0)
        with K.get_session() as sess:
            tf.saved_model.simple_save(
                sess,
                export_helpers.get_timestamped_export_dir(self._output_model_dir),
                inputs=dict(zip([t.name.split(':')[0] for t in model.model.input], model.model.input)),
                outputs=dict(zip([t.name.split('/')[0] for t in model.model.outputs], model.model.outputs))
            )

И у меня есть следующая структура каталогов:

│   ├── 1571292052
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index

Как я могу использовать это для прогнозирования вывода для новых случаев? У меня есть pandas dataframe со всеми входными переменными и функцией предварительной обработки, но я не знаю, как сделать вывод о модели?

Я начал как ниже, но не знаю, как пройти предварительно обработаннуюряды панд в загруженной модели для получения выходных данных:

import tensorflow as tf
import keras.backend as K

with K.get_session() as sess:
    i = tf.saved_model.load(sess,tags={'serve'}, export_dir='./exp_model/1571292052/')

>>> print(type(i))

<class 'tensorflow.core.protobuf.meta_graph_pb2.MetaGraphDef'>

>>> print(dir(i))

['ByteSize', 'Clear', 'ClearExtension', 'ClearField', 'CollectionDefEntry', 'CopyFrom', 'DESCRIPTOR', 'DiscardUnknownFields', 'Extensions', 'FindInitializationErrors', 'FromString', 'HasExtension', 'HasField', 'IsInitialized', 'ListFields', 'MergeFrom', 'MergeFromString', 'MetaInfoDef', 'ParseFromString', 'RegisterExtension', 'SerializePartialToString', 'SerializeToString', 'SetInParent', 'SignatureDefEntry', 'UnknownFields', 'WhichOneof', '_CheckCalledFromGeneratedFile', '_SetListener', '__class__', '__deepcopy__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__unicode__', '_extensions_by_name', '_extensions_by_number', '_tf_api_names', '_tf_api_names_v1', 'asset_file_def', 'collection_def', 'graph_def', 'meta_info_def', 'object_graph_def', 'saver_def', 'signature_def']

И, на выходе:

>>> saved_model_cli show --dir ./exp_model/1571292052 --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['name_counts'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: name_counts:0
    .
    .
    .
    .
  The given SavedModel SignatureDef contains the following output(s):
    outputs['o1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: o1/Sigmoid:0
    outputs['o2'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: o2/BiasAdd:0
  Method name is: tensorflow/serving/predict
...