ERFNet (pytorch) в ONNX - PullRequest
       24

ERFNet (pytorch) в ONNX

0 голосов
/ 21 сентября 2019

Когда я конвертирую модель ERFNet (pytorch) в ONNX.

(Данная модель относится к https://github.com/wvangansbeke/LaneDetection_End2End)
Мой код выглядит следующим образом.

from torch.autograd import Variable
import torch.onnx
import torchvision
from Networks.LSQ_layer import Net
from Networks.utils import define_args, save_weightmap, first_run, \
                           mkdir_if_missing, Logger, define_init_weights,\
                           define_scheduler, define_optim, AverageMeter

dummy_input = Variable(torch.randn(1,3,256,512)).cuda()
parser = define_args()
args = parser.parse_known_args()[0]  

model = Net(args)
define_init_weights(model, args.weight_init)
checkpoint = torch.load("model_best_epoch_204.pth.tar")
model.load_state_dict(checkpoint['state_dict'])
model = model.cuda()

torch.onnx.export(model, dummy_input, "LaneDetection.onnx", verbose=True)

Появляется эта ошибка.ValueError: Автоматическое вложение не знает, как обрабатывать входной объект типа int.Допустимые типы: Тензор или списки / кортежи из них

Traceback (most recent call last):
  File "/tmp/pycharm_project_633/venv/Scripts/darknet2onnx.py", line 22, in <module>
    torch.onnx.export(model, dummy_input, "LaneDetection.onnx", verbose=True)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 27, in export
    return utils.export(*args, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 104, in export
    operator_export_type=operator_export_type)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 281, in _export
    example_outputs, propagate)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 224, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 192, in _trace_and_get_graph_from_model
    trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 197, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 252, in forward
    out = self.inner(*trace_inputs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 487, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/tmp/pycharm_project_633/venv/Scripts/Networks/LSQ_layer.py", line 295, in forward
    shared_encoder, output = self.net(input, end_to_end*self.pretrained)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 487, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 464, in _slow_forward
    input_vars = tuple(torch.autograd.function._iter_tensors(input)) #lulu 
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/autograd/function.py", line 284, in _iter
    for var in _iter(o):
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/autograd/function.py", line 293, in _iter
    if condition_msg else ""))
ValueError: Auto nesting doesn't know how to process an input object of type int. Accepted types: Tensors, or lists/tuples of them

После отладки я обнаружил, что проблема в файле anaconda3 / lib / python3.6 / site-packages / torch / autograd / function.py, line293, в _iter

def _iter_filter(condition, allow_unknown=False, condition_msg=None,
                 conversion=None):#lulu change allow_unknown=False to True
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                for var in _iter(o):
                    yield var
        elif allow_unknown:
            yield obj
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _iter

В этой функции

for var in _iter(o):
    yield var

Она в цикле дает o = (int) 0 и сообщает ValueError.

Почему этоноль будет генерироваться?Как я могу это исправить?Если вы хотите больше деталей, пожалуйста, прокомментируйте.Я отвечу вам, как только смогу.

...