Как использовать PNASNet5 в качестве кодировщика в Unet в pytorch - PullRequest
0 голосов
/ 08 сентября 2018

Я хочу использовать PNASNet5Large в качестве кодера для моего Unet. Вот мой неправильный подход к PNASNet5Large, но он работает для resnet:

class UNetResNet(nn.Module):
def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152: #this works
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 777: #coded version for the pnasnet
            self.encoder = PNASNet5Large()
            bottom_channel_nr = 4320 #this unknown for me as well


        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        self.conv2 = self.encoder.layer1 #PNASNet5Large doesn't have such layers
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 *2, num_filters * 8, False)

        self.dec5 =  DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8,   is_deconv)
        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                   is_deconv)
        self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)
        center = self.center(conv5)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        return self.final(F.dropout2d(dec0, p=self.dropout_2d))

1) Как узнать, сколько нижних каналов имеет pnasnet. Это заканчивается следующим образом:

...
 self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
                            in_channels_right=4320, out_channels_right=864)
        self.relu = nn.ReLU()
        self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
        self.dropout = nn.Dropout(0.5)
        self.last_linear = nn.Linear(4320, num_classes)

Является ли 4320 ответом или нет, in_channels_left и out_channels_left - что-то новое для меня

2) В Resnet есть 4 больших слоя, которые я использую, и кодировщики в моей арке Unet, как получить аналогичный слой из pnasnet

Я использую Pytorch 3.1, и это ссылка на каталог Pnasnet

3) AttributeError: у объекта 'PNASNet5Large' нет атрибута 'conv1' - поэтому он также не имеет conv1

UPD: попробовал что-то подобное, но не получилось

класс UNetPNASNet (nn.Module): def init (self, encoder_depth, num_classes, num_filters = 32, dropout_2d = 0.2, pretrained = False, is_deconv = False): супер (). INIT () self.num_classes = num_classes self.dropout_2d = dropout_2d self.encoder = PNASNet5Large () bottom_channel_nr = 4320 self.center = DecoderCenter (bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        self.dec5  =  DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8,   is_deconv)
        self.dec4  = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3  = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2  = DecoderBlockV2(num_filters * 4 * 4, num_filters * 4 * 4, num_filters, is_deconv)
        self.dec1  = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0  = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

def forward(self, x):
        features = self.encoder.features(x)
        relued_features = self.encoder.relu(features)
        avg_pooled_features = self.encoder.avg_pool(relued_features)
        center = self.center(avg_pooled_features)
        dec5 = self.dec5(torch.cat([center, avg_pooled_features], 1))
        dec4 = self.dec4(torch.cat([dec5, relued_features], 1))
        dec3 = self.dec3(torch.cat([dec4, features], 1))
        dec2 = self.dec2(dec3)
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        return self.final(F.dropout2d(dec0, p=self.dropout_2d))

RuntimeError: Заданный размер ввода: (4320x4x4). Расчетный выходной размер: (4320x-6x-6). Размер вывода слишком мал в /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/THCUNN/generic/SpatialAveragePooling.cu:63

1 Ответ

0 голосов
/ 09 сентября 2018

Так что вы хотите использовать PNASNetLarge вместо o ResNets в качестве кодера в вашей UNet архитектуре. Давайте посмотрим, как используются ResNets. В вашем __init__:

self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(self.encoder.conv1,
                           self.encoder.bn1,
                           self.encoder.relu,
                           self.pool)

self.conv2 = self.encoder.layer1
self.conv3 = self.encoder.layer2
self.conv4 = self.encoder.layer3
self.conv5 = self.encoder.layer4

Таким образом, вы используете ResNets до layer4, что является последним блоком перед средним пулированием, размеры, которые вы используете для реснета, равны после среднего пула, поэтому я предполагаю, что self.encoder.avgpool отсутствует после self.conv5 = self.encoder.layer4. Форвард ResNet в torchvision.models выглядит следующим образом:

def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)

    return x

Полагаю, вы хотите принять аналогичное решение для PNASNet5Large (используйте архитектуру вплоть до среднего уровня пула).

1) Чтобы узнать, сколько каналов имеет ваш PNASNet5Large, вам нужно посмотреть на размер выходного тензора после усредненного пула, например, добавив в него фиктивный тензор. Также обратите внимание, что хотя ResNet обычно используются с размером ввода (batch_size, 3, 224, 224), PNASNetLarge использует (batch_size, 3, 331, 331).

m = PNASNet5Large()
x1 = torch.randn(1, 3, 331, 331)
m.avg_pool(m.features(x1)).size()
torch.Size([1, 4320, 1, 1])

Поэтому да, bottom_channel_nr=4320 для вашей PNASNet.

2) Поскольку архитектура совершенно иная, вам нужно изменить __init__ и forward вашего UNet. Если вы решите использовать PNASNet, я предлагаю вам создать новый класс:

class UNetPNASNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                     pretrained=False, is_deconv=False):
            super().__init__()
            self.num_classes = num_classes
            self.dropout_2d = dropout_2d
            self.encoder = PNASNet5Large()
            bottom_channel_nr = 4320
            self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 *2, num_filters * 8, False)

            self.dec5 =  DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8,   is_deconv)
            self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
            self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                       is_deconv)
            self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
            self.dec0 = ConvRelu(num_filters, num_filters)
            self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

        def forward(self, x):
            features = self.encoder.features(x)
            relued_features = self.encoder.relu(features)
            avg_pooled_features = self.encoder.avg_pool(relued_features)
            center = self.center(avg_pooled_features)
            dec5 = self.dec5(torch.cat([center, conv5], 1))
            dec4 = self.dec4(torch.cat([dec5, conv4], 1))
            dec3 = self.dec3(torch.cat([dec4, conv3], 1))
            dec2 = self.dec2(torch.cat([dec3, conv2], 1))
            dec1 = self.dec1(dec2)
            dec0 = self.dec0(dec1)
            return self.final(F.dropout2d(dec0, p=self.dropout_2d))

3) PNASNet5Large действительно не имеет атрибута conv1. Вы можете проверить это по

'conv1' in list(m.modules())
False
...