Как включить поддержку Dict / OrderedDict / NamedTuple в JIT-компиляторе pytorch 1.1.0? - PullRequest
3 голосов
/ 02 июля 2019

Из релизной версии Pytorch 1.1.0. Похоже, что последний JIT-компилятор теперь поддерживает тип Dict. (Источник: https://jaxenter.com/pytorch-1-1-158332.html)

Поддержка словаря и списков в TorchScript: списки и типы словарей ведут себя как списки и словари Python.

К сожалению, я не могу найти способ заставить это улучшение работать должным образом. Следующий код представляет собой простой пример экспорта Feature Pyramid Network (FPN) в тензорную панель, использующую JIT-компилятор:

from collections import OrderedDict

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

torchWriter = SummaryWriter(log_dir=".tensorboard/example1")

m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
# get some dummy data
x = OrderedDict()
x['feat0'] = torch.rand(1, 10, 64, 64)
x['feat2'] = torch.rand(1, 20, 16, 16)
x['feat3'] = torch.rand(1, 30, 8, 8)
# compute the FPN on top of x
output = m.forward(x)
print([(k, v.shape) for k, v in output.items()])

torchWriter.add_graph(m, input_to_model=x)

Когда я его запустил, я получил следующую ошибку:

Traceback (most recent call last):
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 276, in graph
    trace, _ = torch.jit.get_trace_graph(model, args)
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 231, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs)
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 284, in forward
    in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs, but got collections.OrderedDict

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/peng/git-drone/gate_detection/python/gate_detection/errorcase/tb.py", line 36, in <module>
    torchWriter.add_graph(m, input_to_model=x)
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py", line 534, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, **kwargs))
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 279, in graph
    _ = model(*args)  # don't catch, just print the error message
  File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given

Из сообщения об ошибке видно, что поддержка еще не завершена. Могу ли я доверять основным моментам выпуска? Или я не правильно использую API?

1 Ответ

1 голос
/ 03 июля 2019

Примечания к выпуску точны, хотя и немного расплывчаты. Поддержка словаря / списка / пользовательских классов, описанная в этой ссылке (и официальные примечания к выпуску ), применима только к компилятору TorchScript (в примечаниях к выпуску есть несколько примеров кода), но SummaryWriter по умолчанию запускает трассировщик TorchScript на любом модуле, который вы передаете ему, и трассировщик поддерживает только Tensors и списки / кортежи Tensors.

Таким образом, исправление будет состоять в том, чтобы использовать компилятор TorchScript, а не трассировщик, но для этого требуется:

  1. Доступ к исходному коду
  2. Поддержка скомпилированного вывода (ScriptModule) в Tensorboard

Вы должны подать проблему для (2), и продолжается работа для исправления (1), но это не будет работать в краткосрочной перспективе для этой модели afaik .

...