Я хотел бы преобразовать модель (например, Mobil enet V2) из pytorch в tflite, чтобы запустить ее на мобильном устройстве.
Кому-нибудь удалось это сделать?
Все, что я обнаружил, это метод, использующий ONNX для преобразования модели в промежуточное состояние. Однако это, похоже, не работает должным образом, так как Tensorflow ожидает порядок каналов NHW C, тогда как onnx и pytorch работают с порядком каналов NCHW.
На github есть обсуждение , однако в моем случае преобразование работало без жалоб до «замороженной модели графа тензорного потока», после попытки преобразовать модель в tflite, он жалуется на неправильный порядок каналов ...
Вот мой код до сих пор :
import torch
import torch.onnx
import onnx
from onnx_tf.backend import prepare
# Create random input
input_data = torch.randn(1,3,224,224)
# Create network
model = torch.hub.load('pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=True)
model.eval()
# Forward Pass
output = model(input_data)
# Export model to onnx
filename_onnx = "mobilenet_v2.onnx"
filename_tf = "mobilenet_v2.pb"
torch.onnx.export(model, input_data, filename_onnx)
# Export model to tensorflow
onnx_model = onnx.load(filename_onnx)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(filename_tf)
Все работает без ошибок до этого момента (игнорируя многие предупреждения tf). Затем я ищу имена входных и выходных тензоров, используя netron («input.1» и «473»).
Наконец, я применяю свой обычный tf-график к tf-lite сценарий преобразования из bash:
tflite_convert \
--output_file=mobilenet_v2.tflite \
--graph_def_file=mobilenet_v2.pb \
--input_arrays=input.1 \
--output_arrays=473
Моя конфигурация:
torch 1.6.0.dev20200508 (needs pytorch-nightly to work with mobilenet V2 from torch.hub)
tensorflow-gpu 1.14.0
onnx 1.6.0
onnx-tf 1.5.0
Вот точное сообщение об ошибке, которое я получаю от tflite
:
Unexpected value for attribute 'data_format'. Expected 'NHWC'
Fatal Python error: Aborted
ОБНОВЛЕНИЕ :
Обновление моей конфигурации:
torch 1.6.0.dev20200508
tensorflow-gpu 2.2.0
onnx 1.7.0
onnx-tf 1.5.0
с использованием
tflite_convert \
--output_file=mobilenet_v2.tflite \
--graph_def_file=mobilenet_v2.pb \
--input_arrays=input.1 \
--output_arrays=473 \
--enable_v1_converter # <-- needed for conversion of frozen graphs
, что приводит к другой ошибке:
Exception: <unknown>:0: error: loc("convolution"): 'tf.Conv2D' op is neither a custom op nor a flex op
Обновление :
Вот модель mobil enet v2, загруженная через netron:
Здесь - ссылка gdrive на мои преобразованные файлы onnx и pb