как переопределить предварительно обученный модельный класс в Pytorch? - PullRequest
0 голосов
/ 03 февраля 2020

Я пытаюсь использовать предварительно обученный Re sNet (2 + 1) D [1], но так как его первый уровень использует 3 канала, и я использую только один канал, я полагаю, мне придется переопределить этот класс. Пожалуйста, посмотрите на мою попытку, я получаю сообщение об ошибке:

TypeError: _video_resnet() got multiple values for keyword argument 'stem'

[1] https://pytorch.org/docs/stable/_modules/torchvision/models/video/resnet.html#r2plus1d_18

код:

class R2Plus1dStem4IMAGES(nn.Sequential):
    """R(2+1)D stem is different than the default one as it uses separated 3D convolution
    """
    def __init__(self):
        super(R2Plus1dStem4IMAGES, self).__init__(
            nn.Conv3d(1, 45, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))



model = torchvision.models.video.r2plus1d_18(pretrained=True, stem=R2Plus1dStem4IMAGES)

model.fc = nn.Linear(model.fc.in_features, 3)
...