Я пытаюсь использовать этот пример кода с сайта PyTorch для преобразования модели Python для использования в PyTorch c ++ api (LibTorch).
Converting to Torch Script via Tracing
To convert a PyTorch model to Torch Script via tracing, you must pass an instance of your model along with an example input to the torch.jit.trace function. This will produce a torch.jit.ScriptModule object with the trace of your model evaluation embedded in the module’s forward method:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
Этот пример работает нормальнои сохраняет файл, как и ожидалось.Когда я переключаюсь на эту модель:
model = models.segmentation.deeplabv3_resnet101(pretrained=True)
Это дает мне следующую ошибку:
File "convert.py", line 14, in <module>
traced_script_module = torch.jit.trace(model, example)
File "C:\Python37\lib\site-packages\torch\jit\__init__.py", line 636, in trace
raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])
Я предполагаю, что это потому, что формат example
неверен, но как я могу получить правильный?
Исходя из комментариев ниже, мой новый код:
import torch
import torchvision
from torchvision import models
model = models.segmentation.deeplabv3_resnet101(pretrained=True)
model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
И теперь я получаю сообщение об ошибке:
File "convert.py", line 15, in <module>
traced_script_module = torch.jit.trace(model, example)
File "C:\Python37\lib\site-packages\torch\jit\__init__.py", line 636, in trace
var_lookup_fn, _force_outplace)
RuntimeError: Only tensors and (possibly nested) tuples of tensors are supported as inputs or outputs of traced functions (toIValue at C:\a\w\1\s\windows\pytorch\torch/csrc/jit/pybind_utils.h:91)
(no backtrace available)