"Failed to import metagraph" exception when converting a TensorFlow model to TRT with TF-TRT - PullRequest
06 January 2020

I am trying to convert a TensorFlow model to TRT with the TF-TRT converter, but I get numerous errors when using the official example code. The detailed example code is here:

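# Assumed imports (not shown in the original post); train_dataset is defined
# earlier in the notebook and is also not shown.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.python.compiler.tensorrt import trt_convert as trt

def build_model():
    model = keras.Sequential([
        layers.Dense(64, activation='relu', input_shape=[len(train_dataset.keys())]),
        layers.Dense(64, activation='relu'),
        layers.Dense(1)
    ])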

    optimizer = tf.keras.optimizers.RMSprop(0.001)

    model.compile(loss='mse',
                optimizer=optimizer,
                metrics=['mae', 'mse'])
    return model

model = build_model()

# freeze model

frozen_path = "/root/workspace/tftrt_test/frozen-auto-mpg-test.pb"

output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            K.get_session(), K.get_session().graph.as_graph_def(), ["dense_2/BiasAdd"])

with tf.gfile.FastGFile(frozen_path, 'wb') as f:
    f.write(output_graph_def.SerializeToString())


# convert model to trt
with tf.device('/GPU:2'):
    with tf.gfile.GFile(frozen_path, 'rb') as f:
        frozen_graph = tf.GraphDef()
        frozen_graph.ParseFromString(f.read())
    # Now you can create a TensorRT inference graph from your
    # frozen graph:
    converter = trt.TrtGraphConverter(
        input_graph_def=frozen_graph,
        # nodes_blacklist=['logits', 'classes'],  # output nodes
    )
    trt_graph = converter.convert()
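For reference, as far as I know `convert_variables_to_constants` in TF 1.x does more than swap variables for constants: it also prunes the graph down to the nodes that the listed output names (here `dense_2/BiasAdd`) depend on. The sketch below shows that backward pruning in plain Python on a hypothetical toy graph, where a dict from node name to input strings stands in for a `GraphDef`. `node_name` and `extract_subgraph` are illustrative helpers, not TensorFlow functions (TF 1.x has its own `tf.compat.v1.graph_util.extract_sub_graph`).

```python
from collections import deque


def node_name(input_str):
    """Turn a GraphDef-style input string ('name', 'name:1', '^name') into a node name."""
    name = input_str[1:] if input_str.startswith("^") else input_str
    return name.split(":")[0]


def extract_subgraph(nodes, output_node_names):
    """Return the names of all nodes the given outputs transitively depend on.

    `nodes` is a hypothetical stand-in for a GraphDef: {node name: [input strings]}.
    """
    keep = set()
    queue = deque(output_node_names)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue
        if name not in nodes:
            raise ValueError(f"{name!r} is not in the graph")
        keep.add(name)
        # Walk backwards along both data edges and control edges.
        queue.extend(node_name(inp) for inp in nodes[name])
    return keep


# Toy graph: one Dense layer plus a loss node that inference does not need.
graph = {
    "dense_input": [],
    "dense/kernel": [],
    "dense/MatMul": ["dense_input", "dense/kernel"],
    "dense/bias": [],
    "dense/BiasAdd": ["dense/MatMul", "dense/bias"],
    "target": [],
    "loss/mse": ["dense/BiasAdd", "target"],
}

print(sorted(extract_subgraph(graph, ["dense/BiasAdd"])))
```

Under these assumptions the `loss/mse` and `target` nodes are dropped because `dense/BiasAdd` does not depend on them, and a name that is not in the graph raises `ValueError` in this sketch.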

I got the following exception:

INFO:tensorflow:Linked TensorRT version: (6, 0, 1)
INFO:tensorflow:Loaded TensorRT version: (6, 0, 1)
INFO:tensorflow:Running against TensorRT version 6.0.1
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-14-c26f6f7bc759> in <module>
      9 #         nodes_blacklist=['logits', 'classes']) #output nodes
     10     )
---> 11     trt_graph = converter.convert()
     12 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in convert(self)
    296     assert not self._converted
    297     if self._input_graph_def:
--> 298       self._convert_graph_def()
    299     else:
    300       self._convert_saved_model()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _convert_graph_def(self)
    224     self._add_nodes_blacklist()
    225 
--> 226     self._run_conversion()
    227 
    228   def _collections_to_keep(self, collection_keys):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _run_conversion(self)
    202         grappler_session_config,
    203         self._grappler_meta_graph_def,
--> 204         graph_id=b"tf_graph")
    205     self._converted = True
    206 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/grappler/tf_optimizer.py in OptimizeGraph(config_proto, metagraph, verbose, graph_id, cluster)
     39                                           config_proto.SerializeToString(),
     40                                           metagraph.SerializeToString(),
---> 41                                           verbose, graph_id)
     42   if ret_from_swig is None:
     43     return None

InvalidArgumentError: Failed to import metagraph, check error log for more info.
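The conversion call above leaves `nodes_blacklist` commented out; in the TF-TRT example it was copied from, that argument lists the output nodes. One way to find candidate output names in a frozen graph is to list the nodes that no other node consumes. A minimal plain-Python sketch, where the dict and `candidate_outputs` are hypothetical stand-ins rather than TensorFlow APIs:

```python
def node_name(input_str):
    """Turn a GraphDef-style input string ('name', 'name:1', '^name') into a node name."""
    name = input_str[1:] if input_str.startswith("^") else input_str
    return name.split(":")[0]


def candidate_outputs(nodes):
    """Return nodes that no other node consumes, i.e. likely graph outputs."""
    consumed = {node_name(inp) for inputs in nodes.values() for inp in inputs}
    return sorted(name for name in nodes if name not in consumed)


# Hypothetical frozen two-layer graph; variables are already Const nodes.
frozen = {
    "dense_input": [],
    "dense/kernel": [],
    "dense/MatMul": ["dense_input", "dense/kernel"],
    "dense/bias": [],
    "dense/BiasAdd": ["dense/MatMul", "dense/bias"],
    "dense/Relu": ["dense/BiasAdd"],
    "dense_1/kernel": [],
    "dense_1/MatMul": ["dense/Relu", "dense_1/kernel"],
    "dense_1/bias": [],
    "dense_1/BiasAdd": ["dense_1/MatMul", "dense_1/bias"],
}

print(candidate_outputs(frozen))  # ['dense_1/BiasAdd']
```

With a real frozen graph, the dict could be built as `{n.name: list(n.input) for n in frozen_graph.node}`, since each `NodeDef` carries `name` and `input` fields; whether passing the result as `nodes_blacklist` resolves this particular error is not established by the post.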

I also tried with a SavedModel:

save_model_dir = "/root/workspace/tftrt_test/saved_model_dir"
with tf.device('/GPU:2'):
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(save_model_dir)
    builder.add_meta_graph_and_variables(sess=K.get_session(),
                                         tags=[tf.saved_model.tag_constants.SERVING],
                                         signature_def_map=None)
    builder.save()

# print out
# INFO:tensorflow:No assets to save.
# INFO:tensorflow:No assets to write.
# INFO:tensorflow:SavedModel written to: /root/workspace/tftrt_test/saved_model_dir/saved_model.pb

And when I try to load the model:


with tf.device("/device:GPU:2"):
    converter = trt.TrtGraphConverter(
        input_saved_model_dir=save_model_dir
    )
    converter.convert()

It returns:

INFO:tensorflow:Linked TensorRT version: (6, 0, 1)
INFO:tensorflow:Loaded TensorRT version: (6, 0, 1)
INFO:tensorflow:Running against TensorRT version 6.0.1
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py:245: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
INFO:tensorflow:Froze 0 variables.
INFO:tensorflow:Converted 0 variables to const ops.
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-21-bde49fe68811> in <module>
      6     #     maximum_cached_engines=100)
      7     )
----> 8     converter.convert()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in convert(self)
    298       self._convert_graph_def()
    299     else:
--> 300       self._convert_saved_model()
    301     return self._converted_graph_def
    302 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _convert_saved_model(self)
    285       # TODO(laigd): maybe add back AssetFileDef.
    286 
--> 287     self._run_conversion()
    288 
    289   def convert(self):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _run_conversion(self)
    202         grappler_session_config,
    203         self._grappler_meta_graph_def,
--> 204         graph_id=b"tf_graph")
    205     self._converted = True
    206 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/grappler/tf_optimizer.py in OptimizeGraph(config_proto, metagraph, verbose, graph_id, cluster)
     39                                           config_proto.SerializeToString(),
     40                                           metagraph.SerializeToString(),
---> 41                                           verbose, graph_id)
     42   if ret_from_swig is None:
     43     return None

InvalidArgumentError: Failed to import metagraph, check error log for more info.
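In the SavedModel attempt, the save code passes `signature_def_map=None`, and the load log reports "Froze 0 variables". As far as I understand the TF 1.14 converter, in this mode it decides which nodes to keep from the SignatureDef named by `input_saved_model_signature_key` (`"serving_default"` when not set), collecting the node names behind the signature's input and output tensors. The plain-Python sketch below, with dicts as hypothetical stand-ins for SignatureDef protos, shows that a missing signature yields no node names at all. The log also says the loaded graph has no variables, so this may not be the only problem.

```python
def gather_names(tensor_infos):
    """Map {signature key: 'node:0'} to the set of node names."""
    return {tensor_name.split(":")[0] for tensor_name in tensor_infos.values()}


def nodes_from_signature(signature):
    """Node names a converter could keep, taken from a signature's inputs and outputs."""
    return (gather_names(signature.get("inputs", {}))
            | gather_names(signature.get("outputs", {})))


# Hypothetical prediction signature for the model in the post.
with_signature = {
    "inputs": {"x": "dense_input:0"},
    "outputs": {"y": "dense_2/BiasAdd:0"},
}
# signature_def_map=None: nothing is recorded under "serving_default".
without_signature = {}

print(sorted(nodes_from_signature(with_signature)))  # ['dense_2/BiasAdd', 'dense_input']
print(nodes_from_signature(without_signature))       # set()
```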

This runs with:

  • official NVIDIA container: nvcr.io/nvidia/tensorflow:19.10-py3
  • Python version: 3.6
  • TensorFlow version: 1.14.0
  • CUDA 10
...