Как визуализировать сеть в Pytorch? - PullRequest
0 голосов
/ 23 сентября 2018
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

Я хочу визуализировать resnet из моделей pytorch.Как мне это сделать?Я пытался использовать torchviz, но выдает ошибку:

'ResNet' object has no attribute 'grad_fn'

Ответы [ 2 ]

0 голосов
/ 21 января 2019

Вы можете взглянуть на PyTorchViz (https://github.com/szagoruyko/pytorchviz), "Небольшой пакет для создания визуализаций графиков и трассировок PyTorch."

Example PyTorchViz visualization

0 голосов
/ 26 сентября 2018

make_dot ожидает переменную (т. Е. Тензор с grad_fn), а не саму модель.
try:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...