Ожидаемый 4-мерный ввод для 4-мерного веса 30 1 5 5 0, но вместо этого получил 5-мерный ввод размера [4, 2, 1, 256, 256] - PullRequest
0 голосов
/ 25 января 2020

У меня есть загрузчик данных, как показано ниже, который загружает данные с размером пакета (по умолчанию 8):

class DataLoaderStego(DataLoader):
    def __init__(self, cover_dir, stego_dir, embedding_otf=False,
                 shuffle=False, pair_constraint=False, batch_size=1,
                 transform=None, num_workers=0, pin_memory=False):
        self.pair_constraint = pair_constraint
        self.embedding_otf = embedding_otf
        if pair_constraint and batch_size % 2 == 0:
            dataset = DatasetPair(cover_dir, stego_dir, embedding_otf,
                                  transform)
            _batch_size = int(batch_size / 2)
        else:
            dataset = DatasetNoPair(cover_dir, stego_dir, embedding_otf,
                                    transform)
            _batch_size = batch_size
        if pair_constraint:
            if shuffle:
                sampler = RandomSampler(dataset)
            else:
                sampler = SequentialSampler(dataset)
        else:
            sampler = RandomBalancedSampler(dataset)
        super(DataLoaderStego, self). \
            __init__(dataset, _batch_size, None, sampler,
                     None, num_workers, pin_memory=pin_memory, drop_last=True)
        self.shuffle = shuffle

    def __iter__(self):
        return DataLoaderIterWithReshape(self)
        # if self.pair_constraint:
        # return DataLoaderIterWithReshape(self)
        # else:
        #     return DataLoaderIter(self)

Этот класс создает набор данных, используя класс DatasetPair:

class DatasetPair(Dataset):
    def __init__(self, cover_dir, stego_dir, embedding_otf=False,
                 transform=None):
        self.cover_dir = cover_dir
        self.stego_dir = stego_dir
        self.cover_list = [x.split('/')[-1]
                           for x in glob(cover_dir + '/*')]
        self.transform = transform
        self.embedding_otf = embedding_otf
        assert len(self.cover_list) != 0, "cover_dir is empty"
        # stego_list = ['.'.join(x.split('/')[-1].split('.')[:-1])
        #               for x in glob(stego_dir + '/*')]


    def __getitem__(self, idx):
        idx = int(idx)
        labels = np.array([0, 1], dtype='int32')
        cover_path = os.path.join(self.cover_dir,
                                  self.cover_list[idx])
        print("cover path", cover_path)
        # cover = Image.open(cover_path)
        cover = Image.open(cover_path)
        # print("@@@", cover)
        images = np.empty((2, cover.size[0], cover.size[1], 1),
                          dtype='uint8')
        print("images", images.shape) # Print image
        images[0, :, :, 0] = np.array(cover)
        if self.embedding_otf:
            images[1, :, :, 0] = np.copy(images[0, :, :, 0])
            beta_path = os.path.join(self.stego_dir,
                                     '.'.join(self.cover_list[idx].
                                              split('.')[:-1]) + '.mat')
            beta_map = io.loadmat(beta_path)['pChange']
            rand_arr = np.random.rand(cover.size[0], cover.size[1])
            inf_map = rand_arr < (beta_map / 2.)
            images[1, np.logical_and(
                images[0, :, :, 0] != 255, inf_map), 0] += 1
            inf_map[:, :] = rand_arr > 1 - (beta_map / 2.)
            images[1, np.logical_and(
                images[0, :, :, 0] != 0, inf_map), 0] -= 1
        else:
            stego_path = os.path.join(self.stego_dir,
                                      self.cover_list[idx])
            print("stego path", stego_path)
            images[1, :, :, 0] = Image.open(stego_path)
        samples = {'images': images, 'labels': labels}
        print(images[0].shape)
        if self.transform:
            samples = self.transform(samples)
        return samples

Все мои изображения имеют размер 256 * 256, когда я печатаю cover:

@@@ <PIL.PpmImagePlugin.PpmImageFile image mode=L size=256x256 at 0x7F96631B8748>

Когда я создаю сетевую модель с этими конфигурациями:

class MyNet(nn.Module):
    def __init__(self, with_bn=False, threshold=3):
        super(YeNet, self).__init__()
        self.with_bn = with_bn
        self.preprocessing = SRM_conv2d(1, 0)
        self.TLU = nn.Hardtanh(-threshold, threshold, True)
        ...

    def forward(self, x):
        x = x.float()
        print("########", x.shape)
        x = self.preprocessing(x)
        ...

И запускаю сеть, я получаю эту ошибку:

Traceback (most recent call last):
  File "main.py", line 169, in <module>
    train(epoch)
  File "main.py", line 121, in train
    outputs = net(images)
  File "/home/emad/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/emad/myworks/stego/codes/analysis/YeNet-Pytorch/YeNet.py", line 110, in forward
    x = self.preprocessing(x)
  File "/home/emad/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/emad/myworks/stego/codes/analysis/YeNet-Pytorch/YeNet.py", line 53, in forward
    self.dilation, self.groups
RuntimeError: Expected 4-dimensional input for 4-dimensional weight 30 1 5 5 0, but got 5-dimensional input of size [4, 2, 1, 256, 256] instead

Это мой сетевой вызов:

for batch_idx, data in enumerate(train_loader):
        images, labels = Variable(
            data['images']), Variable(data['labels'])
        if args.cuda:
            images, labels = images.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = net(images)

И это self.preprocessing функция:

class SRM_conv2d(nn.Module):
    def __init__(self, stride=1, padding=0):
        super(SRM_conv2d, self).__init__()
        self.in_channels = 1
        self.out_channels = 30
        self.kernel_size = (5, 5)
        if isinstance(stride, int):
            self.stride = (stride, stride)
        else:
            self.stride = stride
        if isinstance(padding, int):
            self.padding = (padding, padding)
        else:
            self.padding = padding
        self.dilation = (1, 1)
        self.transpose = False
        self.output_padding = (0,)
        self.groups = 1
        print(1)
        self.weight = Parameter(torch.Tensor(30, 1, 5, 5),
                                requires_grad=True)
        print(2)
        self.bias = Parameter(torch.Tensor(30),
                              requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.numpy()[:] = SRM_npy
        self.bias.data.zero_()

    def forward(self, input):
        # print(
        #     "3333", input, self.weight, self.bias,
        #     self.stride, self.padding,
        #     self.dilation, self.groups
        # )
        print("###")
        return F.conv2d(
            input, self.weight, self.bias,
            self.stride, self.padding,
            self.dilation, self.groups
        )
        print("$$$")

Я хочу включить пакетный ввод для моей сети. Но также, когда я запускаю свой код с batch_size=1, получаю ту же ошибку.

Дополнительная информация : Это три моих преобразования, которые я использую:

class ToTensor(object):
    def __call__(self, samples):
        images, labels = samples['images'], samples['labels']
        print("EEE", images.shape)
        images = images.transpose((0, 3, 1, 2))
        # images = (images.transpose((0,3,1,2)).astype('float32') / 127.5) - 1.
        return {'images': torch.from_numpy(images),
                'labels': torch.from_numpy(labels).long()}


class RandomRot(object):
    def __call__(self, samples):
        images = samples['images']
        rot = random.randint(0, 3)
        return {'images': np.rot90(images, rot, axes=[1, 2]).copy(),
                'labels': samples['labels']}


class RandomFlip(object):
    def __call__(self, samples):
        if random.random() < 0.5:
            images = samples['images']
            return {'images': np.flip(images, axis=2).copy(),
                    'labels': samples['labels']}
        else:
            return samples
...