import torch import torch.nn as nn import torch.optim as optim import torch.utils.data as data import torchvision.models as models import torchvision.datasets as dset import torchvision.transforms as transforms from torch.autograd import Variable from torchvision.models.vgg import model_urls from torchviz import make_dot batch_size = 3 learning_rate =0.0002 epoch = 50 resnet = models.resnet50(pretrained=True) print resnet make_dot(resnet)
Я хочу визуализировать resnet из моделей pytorch.Как мне это сделать?Я пытался использовать torchviz, но выдает ошибку:
resnet
torchviz
'ResNet' object has no attribute 'grad_fn'
Вы можете взглянуть на PyTorchViz (https://github.com/szagoruyko/pytorchviz), "Небольшой пакет для создания визуализаций графиков и трассировок PyTorch."
make_dot ожидает переменную (т. Е. Тензор с grad_fn), а не саму модель. try:
make_dot
grad_fn
x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False) out = resnet(x) make_dot(out) # plot graph of variable, not of a nn.Module