Как решить RuntimeError при использовании torch.utils.tensorboard для добавления графика - PullRequest
0 голосов
/ 05 февраля 2020

Я пытаюсь использовать тензорную доску для визуализации моей модели pytorch и сталкиваюсь с проблемой. Форма входного тензора - (-1, 1, 20, 15), а форма выходного тензора - (-1, 6). Моя модель объединяет список из 5 сверточных сетей.

пакетов:

  • python: 3.7.6
  • pytorch: 1.4.0
  • тензорная доска: 2.1.0

Модель pytorch выглядит следующим образом:

import torch
from torch import nn
from torch.nn import functional as F
class MyModel(nn.Module):
    """example"""

    def __init__(self, nchunks=[2, 5, 3, 2, 3], resp_size=6):
        super().__init__()
        self.nchunks = nchunks
        self.conv = [nn.Conv2d(1, 2, (2, x)) for x in nchunks]
        self.pool = nn.Sequential(
            nn.AdaptiveMaxPool1d(output_size=10), nn.Flatten(start_dim=1)
        )
        self.bn = nn.BatchNorm1d(100)
        self.fc1 = nn.Linear(100, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, resp_size)

    def forward(self, x):
        xi = torch.split(x, self.nchunks, dim=3)
        xi = [f(subx.float()).view(-1, 2, 19) for f, subx in zip(self.conv, xi)]
        xi = [self.pool(subx) for subx in xi]
        xi = torch.cat(xi, dim=1)
        xi = self.bn(xi)
        xi = F.relu(self.fc1(xi))
        xi = F.relu(self.fc2(xi))
        xi = self.fc3(xi)
        return xi

Вот код для составителя резюме тензорной доски:

from torch.utils.tensorboard import SummaryWriter
x = torch.rand((5,1,20,15))
model = MyModel()
writer = SummaryWriter('logs')
writer.add_graph(model, x)

Возвращается такая ошибка:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
(1,1,.,.) =
 -0.2108 -0.4986
 -0.4009 -0.1910

(2,1,.,.) =
  0.2383 -0.4147
  0.2642  0.0456
[ torch.FloatTensor{2,1,2,2} ]

Я полагаю, что у модели есть некоторые проблемы, но я не уверен, что произойдет.

Этот аналог github проблема не имеет отношения к моей проблеме, потому что я не использую несколько графических процессоров.

1 Ответ

0 голосов
/ 13 февраля 2020

Я решил проблему, заменив

[nn.Conv2d(1, 2, (2, x)) for x in nchunks]

на

nn.ModuleList([nn.Conv2d(1, 2, (2, x)) for x in nchunks])
...