Как удалить слой maxpooling из сети Pytorch - PullRequest
0 голосов
/ 05 сентября 2018

У меня "простой" Unet с кодировкой resnet на pytorch v3.1, который работает довольно хорошо:

class UNetResNet(nn.Module):
    """PyTorch U-Net model using ResNet(34, 101 or 152) encoder.

    UNet: https://arxiv.org/abs/1505.04597
    ResNet: https://arxiv.org/abs/1512.03385

    Args:
            encoder_depth (int): Depth of a ResNet encoder (34, 101 or 152).
            num_classes (int): Number of output classes.
            num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
            dropout_2d (float, optional): Probability factor of dropout layer before output layer. Defaults to 0.2.
            pretrained (bool, optional):
                False - no pre-trained weights are being used.
                True  - ResNet encoder is pre-trained on ImageNet.
                Defaults to False.
            is_deconv (bool, optional):
                False: bilinear interpolation is used in decoder.
                True: deconvolution is used in decoder.
                Defaults to False.

    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.pool = nn.MaxPool2d(2, 2)

        self.relu = nn.ReLU(inplace=True)

        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        self.conv2 = self.encoder.layer1

        self.conv3 = self.encoder.layer2

        self.conv4 = self.encoder.layer3

        self.conv5 = self.encoder.layer4

        self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
                                   is_deconv)
        self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2,
                                   is_deconv)
        self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                   is_deconv)
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        pool = self.pool(conv5) # that pool layer I would like to delete
        center = self.center(pool)


        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        return self.final(F.dropout2d(dec0, p=self.dropout_2d))

Удалив этот слой пула из forward, мы переходим к следующей forward функции:

def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(conv5) #now center connects with conv5

        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        return self.final(F.dropout2d(dec0, p=self.dropout_2d))

Применим только такие изменения, и мы столкнемся с такой ошибкой:

~/Desktop/ml/salt/unet_models.py in forward(self, x)
    397         center = self.center(conv5)
    398 
--> 399         dec5 = self.dec5(torch.cat([center, conv5], 1))
    400 
    401         dec4 = self.dec4(torch.cat([dec5, conv4], 1))

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 4 and 8 in dimension 2 at /pytorch/torch/lib/THC/generic/THCTensorMath.cu:111

Так что мы должны как-то изменить dec4, но как?

Дополнительная информация:

    class DecoderBlockV2(nn.Module):
        def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
            super(DecoderBlockV2, self).__init__()
            self.in_channels = in_channels

            if is_deconv:
                """
                    Paramaters for Deconvolution were chosen to avoid artifacts, following
                    link https://distill.pub/2016/deconv-checkerboard/
                """

                self.block = nn.Sequential(
                    ConvRelu(in_channels, middle_channels),
                    nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                       padding=1),
            nn.BatchNorm2d(out_channels), 
                    nn.ReLU(inplace=True)
                )
            else:
                self.block = nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='bilinear'),
                    ConvRelu(in_channels, middle_channels),
                    ConvRelu(middle_channels, out_channels),
                )

        def forward(self, x):
            return self.block(x)

class ConvRelu(nn.Module):
    def __init__(self, in_, out):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        return x

1 Ответ

0 голосов
/ 06 сентября 2018

Ваша ошибка связана с тем, что вы пытаетесь объединить (torch.cat) center и conv5 - но размеры этих тензоров не совпадают.
Первоначально у вас были следующие пространственные размеры

  • conv5: 4x4
  • pool: 2x2
  • center: 4x4 # с повышением частоты дискретизации Таким образом, вы можете объединить center и conv5, поскольку они оба являются тензорами 4x4.
    Однако, если вы удалите слой пула, вы получите

  • conv5: 4x4

  • center: 8x8 # увеличивает частоту дискретизации больше, чем вам нужно Теперь вы не можете объединить center и conv5, поскольку их пространственные размеры не совпадают.

Что вы можете сделать?

  1. Один из вариантов - удалить center, а также пул, оставив вам

    dec5 = self.dec5(conv5)
    
  2. Другой вариант - удалить часть nn.Upsample из блока center, оставив вам

    self.block = nn.Sequential(
                    ConvRelu(in_channels, middle_channels),
                    ConvRelu(middle_channels, out_channels),
                )
    
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...