Размер вывода слишком мал для SpatialAveragePooling в Unet - PullRequest
0 голосов
/ 11 сентября 2018

Попробовал использовать Resnext в качестве кодировщика в Unet с здесь , но продолжаю получать RuntimeError: Заданный размер ввода: (4320x4x4).Расчетный выходной размер: (4320x-6x-6).Размер вывода слишком мал: /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/THCUNN/generic/SpatialAveragePooling.cu:63

(входное изображение 128 * 128, размер пакета 32)

class UNetResNext(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
    if encoder_depth == 34:
        self.encoder = resnext34()
        bottom_channel_nr = 512
    elif encoder_depth == 101:
        self.encoder = resnext101()
        bottom_channel_nr = 2048
    elif encoder_depth == 152:
        self.encoder = resnext152()
        bottom_channel_nr = 2048

    else:
        raise NotImplementedError('only 34, 101, 152 version of Resnext are implemented')

    self.pool = nn.MaxPool2d(2, 2)

    self.relu = nn.ReLU(inplace=True)

    self.conv1 = nn.Sequential(self.encoder.conv1,
                               self.encoder.bn1,
                               self.encoder.relu,
                               self.pool)

    self.conv2 = self.encoder.layer1
    self.conv3 = self.encoder.layer2
    self.conv4 = self.encoder.layer3
    self.conv5 = self.encoder.layer4
    self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 *2, num_filters * 8, False)

    self.dec5 =  DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8,   is_deconv)
    self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
    self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
    self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                               is_deconv)
    self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
    self.dec0 = ConvRelu(num_filters, num_filters)
    self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

def forward(self, x):
    conv1 = self.conv1(x)
    conv2 = self.conv2(conv1)
    conv3 = self.conv3(conv2)
    conv4 = self.conv4(conv3)
    conv5 = self.conv5(conv4)
    center = self.center(conv5)
    dec5 = self.dec5(torch.cat([center, conv5], 1))
    dec4 = self.dec4(torch.cat([dec5, conv4], 1))
    dec3 = self.dec3(torch.cat([dec4, conv3], 1))
    dec2 = self.dec2(torch.cat([dec3, conv2], 1))
    dec1 = self.dec1(dec2)
    dec0 = self.dec0(dec1)

    return self.final(F.dropout2d(dec0, p=self.dropout_2d))
...