Я экспортировал свою модель в ONNX через:
# Export the model
torch_out = torch.onnx._export(learn.model, # model being run
x, # model input (or a tuple for multiple inputs)
EXPORT_PATH + "mnist.onnx", # where to save the model (can be a file or file-like object)
export_params=True) # store the trained parameter weights inside the model file
А сейчас я пытаюсь преобразовать модель в файл Tensorflow Lite, чтобы сделать вывод на Android. К сожалению, поддержка PyTorch / Caffe2 практически отсутствует или слишком сложна для Android, но Tensorflow выглядит намного проще.
Документация по ONNX для Tflite довольно проста в этом.
Я пытался экспортировать в прото Tensorflow GraphDef через:
tf_rep.export_graph(EXPORT_PATH + 'mnist-test/mnist-tf-export.pb')
А потом работает toco
:
toco \
--graph_def_file=mnist-tf-export.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=0 \
--output_arrays=add_10 \
--input_shapes=1,3,28,28 \
--output_file=mnist.tflite`
Когда я делаю это, я получаю следующую ошибку:
File "anaconda3/lib/python3.6/site-packages/tensorflow/lite/python/convert.py", line 172, in toco_convert_protos
"TOCO failed. See console for info.\n%s\n%s\n" % (stdout, stderr))
tensorflow.lite.python.convert.ConverterError: TOCO failed. See console for info.
2018-11-06 16:28:33.864889: I tensorflow/lite/toco/import_tensorflow.cc:1268] Converting unsupported operation: PyFunc
2018-11-06 16:28:33.874130: F tensorflow/lite/toco/import_tensorflow.cc:114] Check failed: attr.value_case() == AttrValue::kType (1 vs. 6)
Более того, даже когда я запускаю команду, я не знаю, что указать для input_arrays или output_arrays, поскольку модель изначально была построена в PyTorch.
Кто-нибудь успешно конвертировал свою модель ONNX в TFlite?
Вот файл ONNX, который я пытаюсь преобразовать: https://drive.google.com/file/d/1sM4RpeBVqPNw1WeCROpKLdzbSJPWSK79/view?usp=sharing
Дополнительная информация
- Python 3.6.6 :: Anaconda custom (64-bit)
- onnx. версия = '1.3.0'
- tf. версия = '1.13.0-dev20181106'
- факел. версия = '1.0.0.dev20181029'