Я использую следующий код для генерации квантованной модели tflite
import tensorflow as tf
def representative_dataset_gen():
for _ in range(num_calibration_steps):
# Get sample input data as a numpy array in a method of your choosing.
yield [input]
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
Но в соответствии с квантованием после обучения :
Полученная модель будетполностью квантованный, но для удобства все равно вход и выход с плавающей запятой .
Для компиляции модели tflite для Google Coral Edge TPU Мне также нужны квантованные входные и выходные данные.
В модели я вижу, что первый сетевой уровень преобразует входные данные с плавающей запятой в input_uint8
, а последний слой преобразует output_uint8
в выходные данные с плавающей запятой.Как мне отредактировать модель tflite, чтобы избавиться от первого и последнего слоев с плавающей точкой?
Я знаю, что мог бы установить тип ввода и вывода для uint8 во время преобразования, но это не совместимо ни с какими оптимизациями.Тогда единственный доступный вариант - использовать ложное квантование, что приводит к плохой модели.