Как перебрать слои в Pytorch - PullRequest
0 голосов
/ 15 января 2019

Допустим, у меня есть объект сетевой модели с именем m. Теперь у меня нет предварительной информации о количестве слоев в этой сети. Как создать цикл для перебора его слоя? Я ищу что-то вроде:

Weight=[]
for layer in m._modules:
    Weight.append(layer.weight)

Ответы [ 3 ]

0 голосов
/ 15 января 2019

Вы можете просто получить его, используя model.named_parameters(), который вернет генератор, с которым вы можете перебрать и получить тензоры, его имя и т. Д.

Вот код для повторной тренировки модели:

In [106]: resnet = torchvision.models.resnet101(pretrained=True)

In [107]: for name, param in resnet.named_parameters(): 
     ...:     print(name, param.shape) 

что бы вывести

conv1.weight torch.Size([64, 3, 7, 7])
bn1.weight torch.Size([64])
bn1.bias torch.Size([64])
layer1.0.conv1.weight torch.Size([64, 64, 1, 1])
layer1.0.bn1.weight torch.Size([64])
layer1.0.bn1.bias torch.Size([64])
........
........ and so on

Вы можете найти обсуждение этой темы в как управлять параметрами слоя по его именам /

0 голосов
/ 28 апреля 2019

Если m ваш модуль, то вы можете сделать:

for layer in m.children():
    weights = list(layer.parameters())
0 голосов
/ 15 января 2019

Допустим, у вас есть следующая нейронная сеть.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # define the forward function 
        return x

Теперь давайте напечатаем размер весовых параметров, связанных с каждым NN-слоем.

model = Net()
for name, param in model.named_parameters():
    print(name, param.size())

выход

conv1.weight torch.Size([6, 1, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])

Надеюсь, вы сможете расширить пример, чтобы удовлетворить ваши потребности.

...