Пользовательский класс набора данных Pytorch, выдающий неправильный вывод - PullRequest
0 голосов
/ 19 февраля 2020

Я пытаюсь использовать этот класс, который я построил для набора данных, но он говорит, что это должен быть PIL или ndarray. Я не совсем уверен, что с ним не так. Вот класс, который я использую

class RotateDataset(Dataset):
    def __init__(self, image_list, size,transform = None):
        self.image_list = image_list
        self.size = size
        self.transform = transform
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, idx):

        img = cv2.imread(self.image_list[idx])
        image_height, image_width = img.shape[:2]
        print("ID: ", idx)
        if idx % 2 == 0:
            label = 0 # Set label
            # chose negative or positive rotation
            rotation_degree = random.randrange(35, 50, 1)
            posnegrot = np.random.randint(2)
            if posnegrot == 0:
                #positive rotation
                #rotation_matrix = cv2.getRotationMatrix2D((num_cols/2, num_rows/2), rotation_degree, 1)
                #img = cv2.warpAffine(img, rotation_matrix, (num_cols, num_rows))

                img = rotate_image(img, rotation_degree)
                img = crop_around_center(img, *largest_rotated_rect(image_width,
                                                                image_height,
                                                                math.radians(rotation_degree)))
            else:
                # Negative rotation
                rotation_degree = -rotation_degree
                img = crop_around_center(img, *largest_rotated_rect(image_width,
                                                                image_height,
                                                                math.radians(rotation_degree)))

        else:
           label = 1
        img = cv2.resize(img, self.size, cv2.INTER_AREA)
        return self.transform(img), self.transform(label)

Ошибка, которую он мне дает,

TypeError: pic should be PIL Image or ndarray. Got class 'int'

. Она должна дать мне img (тензор) и метку (тензор), но Я не думаю, что он делает это правильно.

TypeError                                 Traceback (most recent call last)
<ipython-input-34-f47943b2600c> in <module>
      2     train_loss = 0.0
      3     net.train()
----> 4     for image, label in enumerate(train_loader):
      5         if train_on_gpu:
      6             image, label = image.cuda(), label.cuda()

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-28-6c77357ff619> in __getitem__(self, idx)
     35             label = 1
     36         img = cv2.resize(img, self.size, cv2.INTER_AREA)
---> 37         return self.transform(img), self.transform(label)

~\Anaconda3\envs\TF2\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, pic)
     99             Tensor: Converted image.
    100         """
--> 101         return F.to_tensor(pic)
    102 
    103     def __repr__(self):

~\Anaconda3\envs\TF2\lib\site-packages\torchvision\transforms\functional.py in to_tensor(pic)
     53     """
     54     if not(_is_pil_image(pic) or _is_numpy(pic)):
---> 55         raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
     56 
     57     if _is_numpy(pic) and not _is_numpy_image(pic):

TypeError: pic should be PIL Image or ndarray. Got <class 'int'>

1 Ответ

1 голос
/ 19 февраля 2020

Как обсуждалось в комментариях, проблема заключалась в применении преобразования label. Вместо этого label следует просто записать как тензор:

return self.transform(img), torch.tensor(label)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...