Я пытаюсь преобразовать модель тензорного потока в trt с помощью конвертера tf-trt, но получаю многочисленные ошибки, используя официальный пример кода, подробный пример кода здесь здесь
def build_model():
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=[len(train_dataset.keys())]),
layers.Dense(64, activation='relu'),
layers.Dense(1)
])
optimizer = tf.keras.optimizers.RMSprop(0.001)
model.compile(loss='mse',
optimizer=optimizer,
metrics=['mae', 'mse'])
return model
model = build_model()
# freeze model
frozen_path = "/root/workspace/tftrt_test/frozen-auto-mpg-test.pb"
output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
K.get_session(), K.get_session().graph.as_graph_def(), ["dense_2/BiasAdd"])
with tf.gfile.FastGFile(frozen_path, 'wb') as f:
f.write(output_graph_def.SerializeToString())
# convert model to trt
with tf.device('/GPU:2'):
with tf.gfile.GFile(frozen_path, 'rb') as f:
frozen_graph = tf.GraphDef()
frozen_graph.ParseFromString(f.read())
# Now you can create a TensorRT inference graph from your
# frozen graph:
converter = trt.TrtGraphConverter(
input_graph_def=frozen_graph
# nodes_blacklist=['logits', 'classes']) #output nodes
)
trt_graph = converter.convert()
Я получил ниже исключения:
INFO:tensorflow:Linked TensorRT version: (6, 0, 1)
INFO:tensorflow:Loaded TensorRT version: (6, 0, 1)
INFO:tensorflow:Running against TensorRT version 6.0.1
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-14-c26f6f7bc759> in <module>
9 # nodes_blacklist=['logits', 'classes']) #output nodes
10 )
---> 11 trt_graph = converter.convert()
12
/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in convert(self)
296 assert not self._converted
297 if self._input_graph_def:
--> 298 self._convert_graph_def()
299 else:
300 self._convert_saved_model()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _convert_graph_def(self)
224 self._add_nodes_blacklist()
225
--> 226 self._run_conversion()
227
228 def _collections_to_keep(self, collection_keys):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _run_conversion(self)
202 grappler_session_config,
203 self._grappler_meta_graph_def,
--> 204 graph_id=b"tf_graph")
205 self._converted = True
206
/usr/local/lib/python3.6/dist-packages/tensorflow/python/grappler/tf_optimizer.py in OptimizeGraph(config_proto, metagraph, verbose, graph_id, cluster)
39 config_proto.SerializeToString(),
40 metagraph.SerializeToString(),
---> 41 verbose, graph_id)
42 if ret_from_swig is None:
43 return None
InvalidArgumentError: Failed to import metagraph, check error log for more info.
Я также пытался с SavedModel
:
save_model_dir = "/root/workspace/tftrt_test/saved_model_dir"
with tf.device('/GPU:2'):
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(save_model_dir)
builder.add_meta_graph_and_variables(sess=K.get_session(),
tags=[tf.saved_model.tag_constants.SERVING],
signature_def_map=None)
builder.save()
# print out
# INFO:tensorflow:No assets to save.
# INFO:tensorflow:No assets to write.
# INFO:tensorflow:SavedModel written to: /root/workspace/tftrt_test/saved_model_dir/saved_model.pb
И при попытке загрузить модель
with tf.device("/device:GPU:2"):
converter = trt.TrtGraphConverter(
input_saved_model_dir=save_model_dir
)
converter.convert()
Возвращается:
INFO:tensorflow:Linked TensorRT version: (6, 0, 1)
INFO:tensorflow:Loaded TensorRT version: (6, 0, 1)
INFO:tensorflow:Running against TensorRT version 6.0.1
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py:245: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
INFO:tensorflow:Froze 0 variables.
INFO:tensorflow:Converted 0 variables to const ops.
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-21-bde49fe68811> in <module>
6 # maximum_cached_engines=100)
7 )
----> 8 converter.convert()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in convert(self)
298 self._convert_graph_def()
299 else:
--> 300 self._convert_saved_model()
301 return self._converted_graph_def
302
/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _convert_saved_model(self)
285 # TODO(laigd): maybe add back AssetFileDef.
286
--> 287 self._run_conversion()
288
289 def convert(self):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _run_conversion(self)
202 grappler_session_config,
203 self._grappler_meta_graph_def,
--> 204 graph_id=b"tf_graph")
205 self._converted = True
206
/usr/local/lib/python3.6/dist-packages/tensorflow/python/grappler/tf_optimizer.py in OptimizeGraph(config_proto, metagraph, verbose, graph_id, cluster)
39 config_proto.SerializeToString(),
40 metagraph.SerializeToString(),
---> 41 verbose, graph_id)
42 if ret_from_swig is None:
43 return None
InvalidArgumentError: Failed to import metagraph, check error log for more info.
Это выполняется с -
- официальный контейнер nvidia: nvcr.io/nvidia/tensorflow:19.10-py3
- python версия: 3.6
- тензор потока версия: 1.14.0
- cuda 10