Как использовать tflite_convert для квантования сети, обученной с помощью tf.estimator.Estimator? - PullRequest
0 голосов
/ 26 июня 2018

tflite_convert - это скрипт на Python, используемый для вызова TOCO (TensorFlow Lite Optimizing Converter) для преобразования файлов из форматов Tensorflow в файлы, совместимые с tflite.

Я пытаюсь сгенерировать квантованную модель TFlite, начиная с сети, в которой я тренировался с Estimator. Учебный код довольно прост, и я добавил необходимые изменения для тонкой настройки модели в соответствии с требованиями Руководство по квантованию с фиксированной точкой :

def input_fn(mode, num_classes, batch_size=1):
  #[...]
  return {'images': images}, labels

def model_fn(features, labels, num_classes, mode):
  images = features['images']
  with tf.contrib.slim.arg_scope(net_arg_scope()):
    logits, end_points = build_net(...)

  if FLAGS.with_quantization:
    tf.logging.info("Applying quantization to the graph.")
    if mode == tf.estimator.ModeKeys.EVAL:
      tf.contrib.quantize.create_eval_graph()

  tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  total_loss = tf.losses.get_total_loss()    #obtain the regularization losses as well

  if FLAGS.with_quantization:
    tf.logging.info("Applying quantization to the graph.")
    if mode == tf.estimator.ModeKeys.TRAIN:
      tf.contrib.quantize.create_training_graph()

  # Configure the training op, etc [...]
  return tf.estimator.EstimatorSpec(...)

def main(unused_argv):
  regex = FINETUNE_LAYER_RE if not FLAGS.with_quantization else '^((?!_quant).)*$'
  ws_settings = tf.estimator.WarmStartSettings(FLAGS.pretrained_checkpoint, regex)

  # Create the Estimator
  estimator = tf.estimator.Estimator(
    model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode),
    model_dir=FLAGS.model_dir,
    #config=run_config,
    warm_start_from=ws_settings)

  # Set up input functions for training and evaluation
  train_input_fn = lambda : input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, FLAGS.batch_size)
  eval_input_fn = lambda : input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, FLAGS.batch_size)

  #[...]

  train_spec = tf.estimator.TrainSpec(...)
  eval_spec = tf.estimator.EvalSpec(...)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

Первая проблема, с которой я столкнулся, заключается в том, что невозможно просто продолжить обучение с использованием последней контрольной точки после добавления операций квантования. Это потому, что квантование добавляет дополнительные переменные, которые не будут найдены в контрольной точке. Я решил, что написал спецификацию «горячего старта», которая отфильтровывает все новые переменные по имени и использует в качестве контрольной точки «горячего старта» последнюю контрольную точку из обучения.

Теперь я хочу сгенерировать график оценки для сохранения (со связанными переменными), чтобы затем передать его в TOCO через скрипт tflite_convert. Я попытался преобразовать один из SavedModel s, экспортируемых после каждой оценки, но возникает следующая ошибка:

Array conv0_bn / FusedBatchNorm, который является входом для оператора Relu создание выходного массива cell_stem_0 / Relu, не хватает данных min / max, что необходимо для квантования. Либо цель не квантована выходной формат или изменить входной график, чтобы он содержал мин / макс информацию или передайте --default_ranges_min = и --default_ranges_max = если вы не заботитесь о точности результатов. Прервано (ядро сброшено)

Я не знаю, как получить правильную SavedModel или пару GraphDef + контрольных точек (хотя SavedModel предпочтительнее) Кто-нибудь пытался квантовать модель оценки? Как вы генерируете квантованный график оценки?

...