Я сохранил модель со следующими тензорами:
with connected_graph.as_default():
print('\nSaving...')
cwd = os.getcwd()
path = os.path.join(cwd, 'saved_model')
shutil.rmtree(path, ignore_errors=True)
inputs_dict = {
"image_tensor": tf_input
}
outputs_dict = {
"detection_boxes_l1": tf_boxes_l1,
"detection_scores_l1": tf_scores_l1,
"detection_classes_l1": tf_classes_l1,
"max_num_detection": tf_max_num_detection,
"detection_boxes_l2": tf_boxes_l2,
"detection_scores_l2": tf_scores_l2,
"detection_classes_l2": tf_classes_l2
}
tf.saved_model.simple_save(
tf_sess_main, path, inputs_dict, outputs_dict
)
print('Ok')
Теперь я хочу преобразовать график в RT с квантованием INT8. Код выглядит следующим образом:
dataset = tf.data.TFRecordDataset(tf_calib_data_files)
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()
def input_map_fn():
return {'image_tensor:0': features}
converter = trt.TrtGraphConverter(
input_saved_model_dir=output_saved_model_dir,
precision_mode=trt.TrtPrecisionMode.INT8,
is_dynamic_op=True,
nodes_blacklist=['detection_boxes_l1:0','detection_scores_l1:0','detection_classes_l1:0','detection_boxes_l2:0','detection_scores_l2:0','detection_classes_l2:0'],
use_calibration=True)
frozen_graph = converter.convert()
converted_graph_def = converter.calibrate(
fetch_names=['detection_boxes_l1:0','detection_scores_l1:0','detection_classes_l1:0','detection_boxes_l2:0','detection_scores_l2:0','detection_classes_l2:0'],
num_runs=10,
input_map_fn=input_map_fn)
Но это выдает мне следующую ошибку:
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py", line 427, in import_graph_def
graph._c_graph, serialized, options) # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: node 'IteratorGetNext' in input_map does not exist in graph (input_map entry: image_tensor:0->IteratorGetNext:0)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 4, in <module>
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py", line 347, in calibrate
name="")
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py", line 431, in import_graph_def
raise ValueError(str(e))
ValueError: node 'IteratorGetNext' in input_map does not exist in graph (input_map entry: image_tensor:0->IteratorGetNext:0)