Вы можете создать свой собственный класс набора данных (полученный из ImageFolder
) и перегрузить только метод __getitem__
:
class MySpecialDataset(datasets.ImageFolder):
def __init__(self, root, loader=default_loader, is_valid_file=None):
super(MySpecialDataset, self).__init__(root=root, loader=loader, is_valid_file=is_valid_file)
def __getitem__(self, index):
image_path, target = self.samples[index]
# do your magic here
im = cv2.imread(image_path)
im_nonoise = cv2.medianBlur(im, 3)
imgray = cv2.cvtColor(im_nonoise,cv2.COLOR_BGR2GRAY)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
cl1 = clahe.apply(imgray)
ret,thresh = cv2.threshold(cl1,110,255,0)
image, contours, hierarchy = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
img = cv2.drawContours(image, contours, -1, (250,100,120))
# you need to convert img from np.array to torch.tensor
# this has to be done CAREFULLY!
sample = torchvision.transforms.ToTensor()(img)
return sample, target
Получив этот набор данных, вы можете использовать его с базой c pytorch's DataLoader
:
data = MySpecialDataset(train_dir)
train_loader = torch.utils.data.DataLoader(data,batch_size=batch_size,sampler=train_sampler)