Как визуализировать фильтры в CNN с PyTorch - PullRequest
0 голосов
/ 09 апреля 2019

Я новичок в области глубокого обучения и Pytorch. Я хочу визуализировать мой фильтр в моей модели CNN, поэтому я хочу повторить слой в модели CNN, которую я определяю. Но я встречаю ошибку, как показано ниже.

ошибка

Объект 'CNN' не повторяется

объект CNN - моя модель

мой код итерации, как показано ниже:

for index, layer in enumerate(self.model):             
# Forward pass layer by layer
    x = layer(x)

код моей модели, как показано ниже:

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.Conv1 = nn.Sequential( # input image size (1,28,20)
            nn.Conv2d(1, 16, 5, 1, 2), # outputize (16,28,20)
            nn.ReLU(),
            nn.MaxPool2d(2),           #outputize (16,14,10)
        )
        self.Conv2 = nn.Sequential( # input ize ? (16,,14,10)
            nn.Conv2d(16, 32, 5, 1, 2),   #output size(32,14,10)
            nn.ReLU(),
            nn.MaxPool2d(2),        #output size (32,7,5)
        )
        self.fc1 = nn.Linear(32 * 7 * 5, 800) 
        self.fc2 = nn.Linear(800,500)
        self.fc3 = nn.Linear(500,10)
        #self.fc4 = nn.Linear(200,10)

    def forward(self,x):
        x = self.Conv1(x)
        x = self.Conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = F.dropout(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.dropout(x)
        x = F.relu(x)
        x = self.fc3(x)
        #x = F.relu(x)
        #x = self.fc4(x)
        return x

так что любой может сказать мне, как я могу решить эту проблему.

извините, мой английский плохой.

1 Ответ

0 голосов
/ 10 апреля 2019

По сути, вам необходимо получить доступ к функциям в вашей модели и сначала преобразовать эти матрицы в правильную форму, затем вы сможете визуализировать фильтры

import numpy as np
import matplotlib.pyplot as plt
from torchvision import utils

def visTensor(tensor, ch=0, allkernels=False, nrow=8, padding=1): 
    n,c,w,h = tensor.shape

    if allkernels: tensor = tensor.view(n*c, -1, w, h)
    elif c != 3: tensor = tensor[:,ch,:,:].unsqueeze(dim=1)

    rows = np.min((tensor.shape[0] // nrow + 1, 64))    
    grid = utils.make_grid(tensor, nrow=nrow, normalize=True, padding=padding)
    plt.figure( figsize=(nrow,rows) )
    plt.imshow(grid.numpy().transpose((1, 2, 0)))


if __name__ == "__main__":
    layer = 1
    filter = model.features[layer].weight.data.clone()
    visTensor(filter, ch=0, allkernels=False)

    plt.axis('off')
    plt.ioff()
    plt.show()

Вы должны быть в состоянии получить визуальную сетку. enter image description here

Есть еще несколько методов визуализации, вы можете изучить их здесь

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...