Пользовательский модуль Pytorch, использующий существующий модуль CNN - PullRequest
1 голос
/ 27 февраля 2020

Я хочу получить доступ и редактировать отдельные модули в модуле torchvision и настроить вход.

Я знаю, что вы можете редактировать подмодули следующим образом:

import torchvision
resnet18 = torchvision.models.resnet18()
print(resnet18._modules['conv1'])
# Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

Но я хочу создать Пользовательский Net (nn.Module) вид класса, чтобы я мог добавить дополнительные слои позже:

class Sonar(resnet18):
    pass

Выдает ошибку:

----> 1 class Sonar(resnet18):
      2     pass

/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py in __init__(self, block, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer)
    142         self.relu = nn.ReLU(inplace=True)
    143         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
--> 144         self.layer1 = self._make_layer(block, 64, layers[0])
    145         self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
    146                                        dilate=replace_stride_with_dilation[0])

/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py in _make_layer(self, block, planes, blocks, stride, dilate)
    176             self.dilation *= stride
    177             stride = 1
--> 178         if stride != 1 or self.inplanes != planes * block.expansion:
    179             downsample = nn.Sequential(
    180                 conv1x1(self.inplanes, planes * block.expansion, stride),

AttributeError: 'str' object has no attribute 'expansion'

Повторная попытка с Ale xNet

alexnet = torchvision.models.AlexNet()
class Sonar(alexnet):
    pass

Выдает ошибку:

      1 alexnet = torchvision.models.AlexNet()
----> 2 class Sonar(alexnet):
      3     pass

TypeError: __init__() takes from 1 to 2 positional arguments but 4 were given

Ответы [ 2 ]

3 голосов
/ 27 февраля 2020

Следующее должно работать хорошо:

import torch
import torchvision

class Sonar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ins = torchvision.models.resnet18(pretrained=True)
        self.fc1 = torch.nn.Linear(1000, 1) #adding layers
    def forward(self, x):
        out = self.ins(x)
        out = self.fc1(out)
        return out

def run():
    return Sonar()

net = run()
print(net(torch.ones(1,3,224,224))) #testing
1 голос
/ 27 февраля 2020

Можете ли вы импортировать файл python и редактировать его локально?

...