Torchscript несовместим с torch.cat для тензорных списков - PullRequest
0 голосов
/ 27 февраля 2020

Torch.cat выдает ошибку для тензорных списков при использовании в torchscript

Вот минимальный воспроизводимый пример для воспроизведения ошибки

import torch
import torch.nn as nn

"""
Smallest working bug for torch.cat torchscript
"""


class Model(nn.Module):
    """dummy model for showing error"""

    def __init__(self):
        super(Model, self).__init__()
        pass

    def forward(self):
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        out = torch.cat([a, b], axis=2)
        return out


if __name__ == '__main__':
    model = Model()
    print(model())  # works
    torch.jit.script(model)  # throws error

Ожидаемый результат будет выводом torchscript для torch .кошка. Вот сообщение об ошибке:

File "/home/anil/.conda/envs/rnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1423, in _create_methods_from_stubs
    self._c._create_methods(self, defs, rcbs, defaults)
RuntimeError: 
Arguments for call are not valid.
The following operator variants are available:

  aten::cat(Tensor[] tensors, int dim=0) -> (Tensor):
  Keyword argument axis unknown.

  aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> (Tensor(a!)):
  Argument out not provided.

The original call is:
at smallest_working_bug_torch_cat_torchscript.py:19:14
    def forward(self):
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        out = torch.cat([a, b], axis=2)
              ~~~~~~~~~ <--- HERE
        return out

Пожалуйста, дайте мне знать, как исправить или обойти эту проблему.

Спасибо!

1 Ответ

1 голос
/ 27 февраля 2020

изменение axis на dim исправляет ошибку, Оригинальное решение было опубликовано здесь

...