Показать модель PyTorch с несколькими выходами, используя torchviz make_dots - PullRequest
0 голосов
/ 06 февраля 2020

У меня есть модель с несколькими выходами, 4, если быть точным:

 def forward(self, x):
      outputs = []
      for conv, act in zip(self.Convolutions, self.Activations):
           y = conv(x)
           outputs.append(act(y))
      return outputs

Я хотел отобразить ее, используя make_dot из torchviz:

 from torchviz import make_dot
 generator = ...
 batch = next(iter(generator))
 input, output = batch["input"].to(device, dtype=torch.float), batch["output"].to(device, dtype=torch.float)
 dot = make_dot(model(input), params=dict(model.named_parameters()))

Но я получить следующую ошибку:

 File "/opt/local/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/torchviz/dot.py", line 37, in make_dot
 output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
 AttributeError: 'list' object has no attribute 'grad_fn'

Очевидно, что список не имеет функции grad_fn, но в соответствии с этим обсуждением я могу вернуть список выходных данных.

Что я делаю не так?

1 Ответ

1 голос
/ 07 февраля 2020

Модель может вернуть список, но make_dot хочет Tensor. Если выходные компоненты имеют одинаковую форму, я предлагаю использовать torch.cat.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...