У меня есть предварительно обученная модель под названием model.pth
.Я загружаю его и конвертирую в ONNX:
class TempModel(nn.Module):
def dummyFunc():
print("dummy")
model = TempModel()
state_dict = torch.load("/pathToModel.model.pth")
model.load_state_dict(state_dict, strict=False)
dummy_input = torch.randn(10, 3, 256, 256)
torch.onnx.export(model, dummy_input, "myModel.onnx")
Когда я запускаю его, я получаю эту ошибку:
Traceback (most recent call last):
File "onnxconvert.py", line 48, in <module>
torch.onnx.export(model, dummy_input, "myModel.onnx")
File "/Users/sidyakinian/anaconda2/lib/python2.7/site- packages/torch/onnx/__init__.py", line 27, in export
return utils.export(*args, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 104, in export
operator_export_type=operator_export_type)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 281, in _export
example_outputs, propagate)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 224, in _model_to_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 192, in _trace_and_get_graph_from_model
trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 197, in get_trace_graph
return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 252, in forward
out = self.inner(*trace_inputs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 487, in __call__
result = self._slow_forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in _slow_forward
result = self.forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 85, in forward
raise NotImplementedError
NotImplementedError
Если я изменяю model.load_state_dict(state_dict, strict=False)
на model.load_state_dict(state_dict)
, я получаюследующая ошибка:
Traceback (most recent call last):
File "onnxconvert.py", line 45, in <module>
model.load_state_dict(state_dict)
File "/Users/myUserName/anaconda2/lib/python2.7/site- packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TempModel:
Unexpected key(s) in state_dict: "fc.weight", "fc.bias", "head_0.conv_0.bias", "head_0.conv_0.weight_orig", "head_0.conv_0.weight_u", "head_0.conv_0.weight_v", "head_0.conv_1.bias", "head_0.conv_1.weight_orig", "head_0.conv_1.weight_u", "head_0.conv_1.weight_v"...
Имеется более 100 неожиданных ключей, я просто вырезал некоторые из них.
Похоже, я должен реализовать метод forward
в TempModel
, номодель имеет более 100 параметров, и я не создал ее, поэтому я не уверен, как именно это сделать.
Что мне следует сделать здесь, чтобы успешно загрузить и экспортировать модель?Пожалуйста, помогите !!