INT8 quantization - PullRequest


January 18, 2020

I saved the model with the following tensors:

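import os
import shutil

import tensorflow as tf

# connected_graph, tf_sess_main and the tf_* tensors are created earlier in the script
with connected_graph.as_default():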
    print('\nSaving...')
    cwd = os.getcwd()
    path = os.path.join(cwd, 'saved_model')
    shutil.rmtree(path, ignore_errors=True)
    inputs_dict = {
        "image_tensor": tf_input
    }
    outputs_dict = {
        "detection_boxes_l1": tf_boxes_l1,
        "detection_scores_l1": tf_scores_l1,
        "detection_classes_l1": tf_classes_l1,
        "max_num_detection": tf_max_num_detection,
        "detection_boxes_l2": tf_boxes_l2,
        "detection_scores_l2": tf_scores_l2,
        "detection_classes_l2": tf_classes_l2
    }
    # simple_save exports under the 'serve' tag with a 'serving_default' signature
    tf.saved_model.simple_save(
        tf_sess_main, path, inputs_dict, outputs_dict
    )
    print('Ok')
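
Because the calibration code below addresses tensors by name ('image_tensor:0', 'detection_boxes_l1:0', ...), it can help to confirm which graph tensors the exported signature actually points to. A minimal sketch, assuming the TF 1.x-style API exposed as tensorflow.compat.v1 and the 'serve' tag and 'serving_default' signature key that simple_save writes; the helper name signature_tensor_names is hypothetical:

```python
import tensorflow.compat.v1 as tf


def signature_tensor_names(export_dir, signature_key="serving_default"):
    """Return ({input_key: tensor_name}, {output_key: tensor_name}) for a signature.

    Hypothetical helper: loads the SavedModel under the 'serve' tag (the tag
    simple_save uses) into a throwaway graph and reads the signature from it.
    """
    with tf.Graph().as_default() as graph:
        with tf.Session(graph=graph) as sess:
            meta_graph = tf.saved_model.loader.load(
                sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    signature = meta_graph.signature_def[signature_key]
    # Signature keys are just dictionary keys; .name is the real tensor name
    inputs = {key: info.name for key, info in signature.inputs.items()}
    outputs = {key: info.name for key, info in signature.outputs.items()}
    return inputs, outputs
```

Called with the export path from the snippet above, the returned tensor names are the ones that input_map_fn, nodes_blacklist and fetch_names have to match. Keys such as 'image_tensor' are only signature keys and need not equal the underlying tensor names.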

Now I want to convert the graph with TensorRT (TF-TRT) using INT8 quantization. The code is as follows:

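import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Calibration data read from TFRecord files
dataset = tf.data.TFRecordDataset(tf_calib_data_files)
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()

# Maps the graph's input tensor to the tensor produced by the iterator
def input_map_fn():
    return {'image_tensor:0': features}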

converter = trt.TrtGraphConverter(
    input_saved_model_dir=output_saved_model_dir,
    precision_mode=trt.TrtPrecisionMode.INT8,
    is_dynamic_op=True,
    nodes_blacklist=['detection_boxes_l1:0','detection_scores_l1:0','detection_classes_l1:0','detection_boxes_l2:0','detection_scores_l2:0','detection_classes_l2:0'],
    use_calibration=True)

frozen_graph = converter.convert()

converted_graph_def = converter.calibrate(
    fetch_names=['detection_boxes_l1:0','detection_scores_l1:0','detection_classes_l1:0','detection_boxes_l2:0','detection_scores_l2:0','detection_classes_l2:0'],
    num_runs=10,
    input_map_fn=input_map_fn)

But this gives me the following error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py", line 427, in import_graph_def
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: node 'IteratorGetNext' in input_map does not exist in graph (input_map entry: image_tensor:0->IteratorGetNext:0)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 4, in <module>
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py", line 347, in calibrate
    name="")
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py", line 431, in import_graph_def
    raise ValueError(str(e))
ValueError: node 'IteratorGetNext' in input_map does not exist in graph (input_map entry: image_tensor:0->IteratorGetNext:0)
...
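
The traceback shows calibrate() passing the dictionary returned by input_map_fn to import_graph_def with name="" (trt_convert.py, line 347), and the import fails because IteratorGetNext is not a node of the graph being imported into. That is consistent with features having been built beforehand in the default graph rather than in the graph calibrate() creates. The sketch below reproduces the pattern without TensorRT, assuming that calibrate() calls input_map_fn while its own graph is the default one; build_toy_graph_def, run_like_calibrate and the input_map_fn factories are hypothetical stand-ins, not the TF-TRT implementation. The variant that works creates the dataset and iterator inside input_map_fn:

```python
import tensorflow.compat.v1 as tf


def build_toy_graph_def():
    """Hypothetical stand-in for the converted graph: image_tensor -> detection_scores_l1."""
    with tf.Graph().as_default() as graph:
        image = tf.placeholder(tf.float32, [None], name="image_tensor")
        tf.identity(image * 2.0, name="detection_scores_l1")
    return graph.as_graph_def()


def run_like_calibrate(graph_def, input_map_fn, fetch_names, num_runs):
    """Mimics the assumed calibrate() pattern: a fresh graph, input_map_fn() called inside it."""
    with tf.Graph().as_default() as graph:
        fetches = tf.import_graph_def(
            graph_def, input_map=input_map_fn(),
            return_elements=fetch_names, name="")
        with tf.Session(graph=graph) as sess:
            return [sess.run(fetches) for _ in range(num_runs)]


def make_input_map_fn(values):
    """Builds the dataset and iterator inside input_map_fn, i.e. in whatever
    graph is the default one at the moment input_map_fn() is called."""
    def input_map_fn():
        dataset = tf.data.Dataset.from_tensor_slices(values).batch(2)
        features = tf.data.make_one_shot_iterator(dataset).get_next()
        return {"image_tensor:0": features}
    return input_map_fn


def make_tfrecord_input_map_fn(tfrecord_files):
    """Same idea for TFRecord calibration data. The 'image/encoded' JPEG feature
    key is an assumption (the Object Detection API convention), as is a uint8
    [batch, H, W, 3] input; adjust both to how the records and model were built."""
    def parse(serialized):
        parsed = tf.io.parse_single_example(
            serialized, {"image/encoded": tf.io.FixedLenFeature([], tf.string)})
        return tf.io.decode_jpeg(parsed["image/encoded"], channels=3)

    def input_map_fn():
        # repeat() keeps the iterator from running out before num_runs steps
        dataset = tf.data.TFRecordDataset(tfrecord_files).map(parse).batch(1).repeat()
        return {"image_tensor:0": tf.data.make_one_shot_iterator(dataset).get_next()}
    return input_map_fn
```

With the converter from the question, this would be used as converter.calibrate(fetch_names=[...], num_runs=10, input_map_fn=make_tfrecord_input_map_fn(tf_calib_data_files)), with 'image_tensor:0' replaced by the real input tensor name if it differs.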