Изменение функции forward в модели vgg - PullRequest
0 голосов
/ 26 июня 2019

Мне нужно изменить существующий метод пересылки в VGG16, чтобы он мог проходить через два классификатора и возвращать значение

Я попытался создать пользовательский метод пересылки вручную и переопределить существующий метод, но получаю следующую ошибку

vgg.forward = forward

forward () отсутствует 1 обязательный позиционный аргумент: 'x'

Моя пользовательская функция пересылки

def forward(self,x):
    x = self.features(x)
    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.classifier(x)
    y = self.classifier_2(x)
    return x,y

Я изменил vgg16_bn по умолчанию с одним дополнительным классификатором как

vgg = models.vgg16_bn()
final_in_features = vgg.classifier[6].in_features
mod_classifier = list(vgg.classifier.children())[:-1]
mod_classifier.extend([nn.Linear(final_in_features, 10)])
vgg.add_module('classifier_2',vgg.classifier)

Моя модель выглядит так после добавления вышеуказанного классификатора

(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
  (classifier_2): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )

Предполагается, что результаты моих сверточных слоев будут проходить через два отдельных слоя FFN. Так как мне изменить мой прямой проход

1 Ответ

1 голос
/ 26 июня 2019

Я думаю, что лучший способ достичь того, чего вы хотите, - это создать новую модель, расширяющую nn.Module.Я бы сделал что-то вроде:

from torchvision import models
from torch import nn

class MyVgg (nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        vgg = models.vgg16_bn(pretrained=True)

        # Here you get the bottleneck/feature extractor
        self.vgg_feature_extractor = nn.Sequential(*list(vgg.children())[:-1])

        # Now you can include your classifiers
        self.classifier1 = nn.Sequential(layers1)
        self.classifier2 = nn.Sequential(layers2)

    # Set your own forward pass
    def forward(self, img, extra_info=None):

        x = self.vgg_convs(img)
        x = x.view(x.size(0), -1)
        x1 = self.classifier1(x)
        x2 = self.classifier2(x)

        return x1, x2

Дайте мне знать, помогло ли это вам.Удачи.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...