Ошибка при преобразовании pytorch в тензорный поток через onnx - PullRequest
0 голосов
/ 29 января 2019

У меня есть предварительно обученная модель (https://github.com/zraurum/EAST-pretrained-model) в Pytorch 1.0 (расширение .pth)

Я пытался преобразовать Pytorch 1.0 модель через onnx 1.3.0 в tensorflow 1.12

import onnx
import torch
import torch._C as _C
from onnx_tf.backend import prepare
from torch.autograd import Variable

OperatorExportTypes = _C._onnx.OperatorExportTypes

from models.east.east_model import East

trained_model = East().cpu()
trained_model.load_state_dict(torch.load('./models/east/model/epoch5300.pth', map_location='cpu'), strict=False)

dummy_input = Variable(torch.randn(1, 3, 672, 1280))    

torch.onnx.export(trained_model, dummy_input, "output/east.onnx", operator_export_type=OperatorExportTypes.RAW)

model = onnx.load('output/ep_180_sim_autoencoder.onnx')

# from onnx to tensorflow
tf_rep = prepare(model)

tf_rep.export_graph('output/ep_180_sim_autoencoder.pb')

Ошибка:

Traceback (последний вызов был последним): файл "convert_east.py", строка 22, в файле tf_rep = prepare (model)"/home/www/frompytorchtoonnx/venvt/src/onnx-tf/onnx_tf/backend.py", строка 53, в подготовительном супер (TensorflowBackend, cls) .prepare (файл модели, устройства, ** kwargs) "/ home /www / frompytorchtoonnx / venvt / lib / python3.5 / site-packages / onnx / backend / base.py ", строка 74, в подготовительном файле onnx.checker.check_model (модель)" / home / www / frompytorchtoonnx / venvt / lib/python3.5/site-packages/onnx/checker.py ", строка 82, в check_model C.check_model (model.SerializeToString ()) onnx.onnx_cpp2py_export.checker.ValidationError: Нет импорта параметров для домена 'org.pytorch.prim'

Контекст: Неверная спецификация узла: выход: «31» op_type: «Константа» атрибут {имя: «значение», целые: 1 целое: 1 тип: INTS} домен: «org.pytorch.prim»

Виртуальная среда содержит следующие библиотеки:

absl-py==0.7.0
astor==0.7.1
gast==0.2.2
grpcio==1.18.0
h5py==2.9.0
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
Markdown==3.0.1
numpy==1.16.0
onnx==1.3.0
-e git+https://github.com/onnx/onnx-tensorflow@c4a75b09e85ffb7b5afda14b64194ca972f957bf#egg=onnx_tf
Pillow==5.4.1
pkg-resources==0.0.0
protobuf==3.6.1
PyYAML==3.13
six==1.12.0
tensorboard==1.12.2
tensorflow==1.12.0
termcolor==1.1.0
torch==1.0.0
torchvision==0.2.1
typing==3.6.6
typing-extensions==3.7.2
Werkzeug==0.14.1
...