pytorch torch.jit.trace возвращает функцию вместо torch.jit.ScriptModule - PullRequest
0 голосов
/ 12 февраля 2019

Мне нужно запустить на c ++ предварительно обученную модель pytorch nn (обученную на python), чтобы делать прогнозы.

Для этого я следую инструкциям по загрузке модели pytorch в c ++здесь: https://pytorch.org/tutorials/advanced/cpp_export.html

Но когда я пытаюсь получить torch.jit.ScriptModule через трассировку, как указано в первом шаге урока:

    traced_script_module =
        torch.jit.trace(model, (input_tensor_1, input_tensor_2))

Вместо возврата факела.jit.ScriptModule возвращает функцию:

    print(type(traced_script_module))
    <type 'function'>

, которая при запуске:

    traced_script_module.save("model.pt")

приводит к следующей ошибке:

Traceback (most recent call last):
  File "serialize_model.py", line 60, in <module>
    traced_script_module.save("model.pt")
AttributeError: 'function' object has no attribute 'save'

Есть идеи, что я делаю не так?

1 Ответ

0 голосов
/ 13 февраля 2019

Спасибо, что спросили Jatentaki .Я использовал PyTorch 0.4 в Python, и когда я обновился до 1.0, он работал.

...