Я обучил агента ProGAN, используя это переопределение PyTorch, и я сохранил агент как .pth
. Теперь мне нужно преобразовать агент в формат .onnx
, который я делаю, используя этот scipt:
from torch.autograd import Variable
import torch.onnx
import torchvision
import torch
device = torch.device("cuda")
dummy_input = torch.randn(1, 3, 64, 64)
state_dict = torch.load("GAN_agent.pth", map_location = device)
torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")
. После запуска я получаю ошибку AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'
(полное приглашение ниже). Насколько я понял, проблема в том, что преобразование агента в .onnx требует больше информации. Я что-то упустил?
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-2-c64481d4eddd> in <module>
10 state_dict = torch.load("GAN_agent.pth", map_location = device)
11
---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
146 operator_export_type, opset_version, _retain_param_name,
147 do_constant_folding, example_outputs,
--> 148 strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
149
150
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
64 _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
65 example_outputs=example_outputs, strip_doc_string=strip_doc_string,
---> 66 dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
67
68
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)
414 example_outputs, propagate,
415 _retain_param_name, do_constant_folding,
--> 416 fixed_batch_size=fixed_batch_size)
417
418 # TODO: Don't allocate a in-memory string for the protobuf
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)
277 model.graph, tuple(in_vars), False, propagate)
278 else:
--> 279 graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
280 state_dict = _unique_state_dict(model)
281 params = list(state_dict.values())
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args, training)
226 # A basic sanity check: make sure the state_dict keys are the same
227 # before and after running the model. Fail fast!
--> 228 orig_state_dict_keys = _unique_state_dict(model).keys()
229
230 # By default, training=False, which is good because running a model in
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\jit\__init__.py in _unique_state_dict(module, keep_vars)
283 # id(v) doesn't work with it. So we always get the Parameter or Buffer
284 # as values, and deduplicate the params using Parameters and Buffers
--> 285 state_dict = module.state_dict(keep_vars=True)
286 filtered_dict = type(state_dict)()
287 seen_ids = set()
AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'