RuntimeError: Заданные группы = 1, вес размера [64, 3, 3, 3], ожидаемый вход [4, 5000, 5000, 3] будет иметь 3 канала, но вместо этого получит 5000 каналов - PullRequest
0 голосов
/ 27 июня 2019

Итак, у меня есть модель U-Net, и я подаю изображения 5000x5000x3 в модель и получаю ошибку выше.

Итак, вот моя модель.

import torch
import torch.nn as nn


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )


class UNeT(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
                                    align_corners=True)
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)
        x = self.dconv_down4(x)
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        return out


Я пытался сделать модель (input.unsqueeze_ (0)), но я получил другую ошибку.

1 Ответ

0 голосов
/ 27 июня 2019

Порядок размеров в pytorch отличается от того, что вы ожидаете. Ваш входной тензор имеет shape из 4x5000x5000x3, который вы интерпретируете как пакет размером 4, с изображениями 5000x5000 пикселей, каждый пиксель имеет 3 канала. То есть ваши размеры batch - height - width - channel.

Однако, Pytorch ожидает, что тензорные размеры будут в другом порядке: batch - channel - height - width. То есть размер channel должен предшествовать пространственным измерениям ширины и высоты.

Вам нужно permute размеры вашего входного тензора, чтобы решить вашу проблему:

model(inputs.permute(0, 3, 1, 2))

Для получения дополнительной информации см. Документацию nn.Conv2d.

...