проблема преобразования: pytorch -> ONNX -> TensorFlow -> tflite - PullRequest
0 голосов
/ 13 июля 2020

Я создал простой нейронный net (для классификации изображений в наборе данных CIFAR, приведенном в официальной документации pytorch 60 минут. Блиц-раздел). Теперь, после выполнения следующего преобразования:

pytorch -> ONNX -> TensorFlow -> tflite

Я использую следующий код для преобразования .pb -> .tflite:

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file="../models/my_cifar_net.pb", input_arrays=["my_input"], output_arrays=["my_output"])

converter.optimizations = [tf.lite.Optimize.DEFAULT]
tf_lite_model = converter.convert()

with tf.io.gfile.GFile('../models/my_cifar_net.tflite', 'wb') as f:
    f.write(tf_lite_model)

Все остальные преобразованные модели ( например, .pth, .onnx, .pb) дают правильный вывод для данных примера, кроме .tflite, так как я получаю следующую ошибку:

снимок ошибки

Я использую следующие версии diff. библиотеки:

  • numpy == 1.18.2
  • pandas == 1.0.3
  • torch == 1.5.1
  • torchvision == 0.6.1
  • onnx == 1.7.0
  • coremltools == 4.0b1
  • onnxmltools == 1.7.0
  • onnx-tf == 1.6 .0
  • tenorflow == 2.2.0
  • tenorflow-addons == 0.10.0
...