Как квантовать входы и выходы оптимизированной модели tflite - PullRequest
1 голос
/ 02 июля 2019

Я использую следующий код для генерации квантованной модели tflite

import tensorflow as tf

def representative_dataset_gen():
  for _ in range(num_calibration_steps):
    # Get sample input data as a numpy array in a method of your choosing.
    yield [input]

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()

Но в соответствии с квантованием после обучения :

Полученная модель будетполностью квантованный, но для удобства все равно вход и выход с плавающей запятой .

Для компиляции модели tflite для Google Coral Edge TPU Мне также нужны квантованные входные и выходные данные.

В модели я вижу, что первый сетевой уровень преобразует входные данные с плавающей запятой в input_uint8, а последний слой преобразует output_uint8 в выходные данные с плавающей запятой.Как мне отредактировать модель tflite, чтобы избавиться от первого и последнего слоев с плавающей точкой?

Я знаю, что мог бы установить тип ввода и вывода для uint8 во время преобразования, но это не совместимо ни с какими оптимизациями.Тогда единственный доступный вариант - использовать ложное квантование, что приводит к плохой модели.

1 Ответ

1 голос
/ 08 июля 2019

Вы можете избежать с плавающей точкой на int8 и int8, чтобы с плавающей точкой "квант / деквант", установив inference_input_type и inference_output_type (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/lite.py#L460-L476) в int8.

...