Я создал простой нейронный net (для классификации изображений в наборе данных CIFAR, приведенном в официальной документации pytorch 60 минут. Блиц-раздел). Теперь, после выполнения следующего преобразования:
pytorch -> ONNX -> TensorFlow -> tflite
Я использую следующий код для преобразования .pb -> .tflite
:
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
graph_def_file="../models/my_cifar_net.pb", input_arrays=["my_input"], output_arrays=["my_output"])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tf_lite_model = converter.convert()
with tf.io.gfile.GFile('../models/my_cifar_net.tflite', 'wb') as f:
f.write(tf_lite_model)
Все остальные преобразованные модели ( например, .pth, .onnx, .pb
) дают правильный вывод для данных примера, кроме .tflite
, так как я получаю следующую ошибку:
снимок ошибки
Я использую следующие версии diff. библиотеки:
- numpy == 1.18.2
- pandas == 1.0.3
- torch == 1.5.1
- torchvision == 0.6.1
- onnx == 1.7.0
- coremltools == 4.0b1
- onnxmltools == 1.7.0
- onnx-tf == 1.6 .0
- tenorflow == 2.2.0
- tenorflow-addons == 0.10.0