Преобразование модели Keras с TFLiteConverter в квантованную версию tflite приводит к ошибке NOTYPE - PullRequest
0 голосов
/ 22 января 2020

При преобразовании и выполнении 8-битного квантования модели keras я столкнулся со странной ошибкой, которая не произошла для наборов данных изображения.

import tensorflow.python.keras.backend as K
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import load_model
import numpy as np

x_train = np.array([[0.6171875  0.59791666],[0.6171875  0.59791666],[0.6171875  0.59791666]])
y_train = np.array([[0.6171875  0.59791666],[0.6171875  0.59791666],[0.6171875  0.59791666]])


def representative_dataset_gen():
    for i in range(1):
        # Get sample input data as a numpy array in a method of your choosing.
        sample = np.array([0.5,0.6])
        sample = np.expand_dims(sample, axis=0)
        yield [sample]



model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(2,)),
  tf.keras.layers.Dense(12, activation='relu'),
  tf.keras.layers.Dense(2, activation='softmax')
])


model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=1)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.float32
converter.inference_output_type = tf.float32
converter.representative_dataset = representative_dataset_gen

tflite_quant_model = converter.convert()

Это приводит к ошибке

ValueError: Cannot set tensor: Got value of type NOTYPE but expected type FLOAT32 for input 1, name: dense_1_input

Эта процедура работала при использовании данных изображения, но теперь это происходит. Пробовал разные версии TF, включая ночные TF2.1.

1 Ответ

0 голосов
/ 22 января 2020

Очевидно, что проблема связана с типом данных входного тензора, который по умолчанию был Float64, а не ожидаемым Float32. Поскольку tflite не знает о Float64, он воспринимает его как NOTYPE, что сбивает с толку.

Поддерживаемые типы TF Lite

Приведение к float32 решает проблему

sample = sample.astype(np.float32)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...