Как применить фильтры OpenCV к набору данных Pytorch? - PullRequest
0 голосов
/ 14 марта 2020

Я использовал следующие шаги для предварительной обработки одного изображения с использованием OpenCV. Теперь я хочу применить эти шаги предварительной обработки ко всему набору данных перед тем, как приступить к обучению модели в Pytorch. Как это может быть сделано?

im = cv2.imread(image_path)
im_nonoise = cv2.medianBlur(im, 3)
imgray = cv2.cvtColor(im_nonoise,cv2.COLOR_BGR2GRAY)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
cl1 = clahe.apply(imgray)
ret,thresh = cv2.threshold(cl1,110,255,0)
image, contours, hierarchy = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
img = cv2.drawContours(image, contours, -1, (250,100,120))

И я загружаю данные, используя

data = datasets.ImageFolder(train_dir,transform=transform)
train_loader = torch.utils.data.DataLoader(data,batch_size=batch_size,sampler=train_sampler)

1 Ответ

0 голосов
/ 14 марта 2020

Вы можете создать свой собственный класс набора данных (полученный из ImageFolder) и перегрузить только метод __getitem__:

class MySpecialDataset(datasets.ImageFolder):
  def __init__(self, root, loader=default_loader, is_valid_file=None):
    super(MySpecialDataset, self).__init__(root=root, loader=loader, is_valid_file=is_valid_file)

  def __getitem__(self, index):
    image_path, target = self.samples[index]
    # do your magic here
    im = cv2.imread(image_path)
    im_nonoise = cv2.medianBlur(im, 3)
    imgray = cv2.cvtColor(im_nonoise,cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    cl1 = clahe.apply(imgray)
    ret,thresh = cv2.threshold(cl1,110,255,0)
    image, contours, hierarchy = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
    img = cv2.drawContours(image, contours, -1, (250,100,120))
    # you need to convert img from np.array to torch.tensor
    # this has to be done CAREFULLY!
    sample = torchvision.transforms.ToTensor()(img)
    return sample, target

Получив этот набор данных, вы можете использовать его с базой c pytorch's DataLoader:

data = MySpecialDataset(train_dir)
train_loader = torch.utils.data.DataLoader(data,batch_size=batch_size,sampler=train_sampler)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...