Мне нужно запустить на c ++ предварительно обученную модель pytorch nn (обученную на python), чтобы делать прогнозы.
Для этого я следую инструкциям по загрузке модели pytorch в c ++здесь: https://pytorch.org/tutorials/advanced/cpp_export.html
Но когда я пытаюсь получить torch.jit.ScriptModule через трассировку, как указано в первом шаге урока:
traced_script_module =
torch.jit.trace(model, (input_tensor_1, input_tensor_2))
Вместо возврата факела.jit.ScriptModule возвращает функцию:
print(type(traced_script_module))
<type 'function'>
, которая при запуске:
traced_script_module.save("model.pt")
приводит к следующей ошибке:
Traceback (most recent call last):
File "serialize_model.py", line 60, in <module>
traced_script_module.save("model.pt")
AttributeError: 'function' object has no attribute 'save'
Есть идеи, что я делаю не так?