NotImplementedError при преобразовании модели PyTorch в ONNX - PullRequest
0 голосов
/ 28 апреля 2019

У меня есть предварительно обученная модель под названием model.pth.Я загружаю его и конвертирую в ONNX:

class TempModel(nn.Module):
    def dummyFunc():
        print("dummy")

model = TempModel()
state_dict = torch.load("/pathToModel.model.pth")
model.load_state_dict(state_dict, strict=False)

dummy_input = torch.randn(10, 3, 256, 256)
torch.onnx.export(model, dummy_input, "myModel.onnx")

Когда я запускаю его, я получаю эту ошибку:

Traceback (most recent call last):
File "onnxconvert.py", line 48, in <module>
  torch.onnx.export(model, dummy_input, "myModel.onnx")
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-  packages/torch/onnx/__init__.py", line 27, in export
  return utils.export(*args, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 104, in export
  operator_export_type=operator_export_type)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 281, in _export
  example_outputs, propagate)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 224, in _model_to_graph
  graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 192, in _trace_and_get_graph_from_model
  trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 197, in get_trace_graph
  return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
  result = self.forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 252, in forward
  out = self.inner(*trace_inputs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 487, in __call__
  result = self._slow_forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in _slow_forward
  result = self.forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 85, in forward
  raise NotImplementedError
NotImplementedError

Если я изменяю model.load_state_dict(state_dict, strict=False) на model.load_state_dict(state_dict), я получаюследующая ошибка:

Traceback (most recent call last):
File "onnxconvert.py", line 45, in <module>
  model.load_state_dict(state_dict)
File "/Users/myUserName/anaconda2/lib/python2.7/site-   packages/torch/nn/modules/module.py", line 769, in load_state_dict
  self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TempModel:
Unexpected key(s) in state_dict: "fc.weight", "fc.bias",    "head_0.conv_0.bias", "head_0.conv_0.weight_orig", "head_0.conv_0.weight_u", "head_0.conv_0.weight_v", "head_0.conv_1.bias", "head_0.conv_1.weight_orig", "head_0.conv_1.weight_u", "head_0.conv_1.weight_v"...

Имеется более 100 неожиданных ключей, я просто вырезал некоторые из них.

Похоже, я должен реализовать метод forward в TempModel, номодель имеет более 100 параметров, и я не создал ее, поэтому я не уверен, как именно это сделать.

Что мне следует сделать здесь, чтобы успешно загрузить и экспортировать модель?Пожалуйста, помогите !!

...