Я пытаюсь использовать этот класс, который я построил для набора данных, но он говорит, что это должен быть PIL или ndarray. Я не совсем уверен, что с ним не так. Вот класс, который я использую
class RotateDataset(Dataset):
def __init__(self, image_list, size,transform = None):
self.image_list = image_list
self.size = size
self.transform = transform
def __len__(self):
return len(self.image_list)
def __getitem__(self, idx):
img = cv2.imread(self.image_list[idx])
image_height, image_width = img.shape[:2]
print("ID: ", idx)
if idx % 2 == 0:
label = 0 # Set label
# chose negative or positive rotation
rotation_degree = random.randrange(35, 50, 1)
posnegrot = np.random.randint(2)
if posnegrot == 0:
#positive rotation
#rotation_matrix = cv2.getRotationMatrix2D((num_cols/2, num_rows/2), rotation_degree, 1)
#img = cv2.warpAffine(img, rotation_matrix, (num_cols, num_rows))
img = rotate_image(img, rotation_degree)
img = crop_around_center(img, *largest_rotated_rect(image_width,
image_height,
math.radians(rotation_degree)))
else:
# Negative rotation
rotation_degree = -rotation_degree
img = crop_around_center(img, *largest_rotated_rect(image_width,
image_height,
math.radians(rotation_degree)))
else:
label = 1
img = cv2.resize(img, self.size, cv2.INTER_AREA)
return self.transform(img), self.transform(label)
Ошибка, которую он мне дает,
TypeError: pic should be PIL Image or ndarray. Got class 'int'
. Она должна дать мне img (тензор) и метку (тензор), но Я не думаю, что он делает это правильно.
TypeError Traceback (most recent call last)
<ipython-input-34-f47943b2600c> in <module>
2 train_loss = 0.0
3 net.train()
----> 4 for image, label in enumerate(train_loader):
5 if train_on_gpu:
6 image, label = image.cuda(), label.cuda()
~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
383 def _next_data(self):
384 index = self._next_index() # may raise StopIteration
--> 385 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
386 if self._pin_memory:
387 data = _utils.pin_memory.pin_memory(data)
~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
<ipython-input-28-6c77357ff619> in __getitem__(self, idx)
35 label = 1
36 img = cv2.resize(img, self.size, cv2.INTER_AREA)
---> 37 return self.transform(img), self.transform(label)
~\Anaconda3\envs\TF2\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, pic)
99 Tensor: Converted image.
100 """
--> 101 return F.to_tensor(pic)
102
103 def __repr__(self):
~\Anaconda3\envs\TF2\lib\site-packages\torchvision\transforms\functional.py in to_tensor(pic)
53 """
54 if not(_is_pil_image(pic) or _is_numpy(pic)):
---> 55 raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
56
57 if _is_numpy(pic) and not _is_numpy_image(pic):
TypeError: pic should be PIL Image or ndarray. Got <class 'int'>