Как использовать torchvision.transforms для увеличения данных задачи сегментации в Pytorch? - PullRequest
0 голосов
/ 03 октября 2019

Я немного озадачен расширением данных, выполненным в PyTorch.

Поскольку мы имеем дело с задачами сегментации, нам нужны данные и маска для того же увеличения данных, но некоторые из них являются случайными, например, случайное вращение.

Keras предоставляет random seed гарантию того, что данные и маска выполняют одну и ту же операцию, как показано в следующем коде:

    data_gen_args = dict(featurewise_center=True,
                         featurewise_std_normalization=True,
                         rotation_range=25,
                         horizontal_flip=True,
                         vertical_flip=True)


    image_datagen = ImageDataGenerator(**data_gen_args)
    mask_datagen = ImageDataGenerator(**data_gen_args)

    seed = 1
    image_generator = image_datagen.flow(train_data, seed=seed, batch_size=1)
    mask_generator = mask_datagen.flow(train_label, seed=seed, batch_size=1)

    train_generator = zip(image_generator, mask_generator)

Я не нашел похожего описания в официальном Pytorchдокументации, поэтому я не знаю, как обеспечить синхронную обработку данных и маски.

Pytorch предоставляет такую ​​функцию, но я хочу применить ее к пользовательскому загрузчику данных.

Например:

def __getitem__(self, index):
    img = np.zeros((self.im_ht, self.im_wd, channel_size))
    mask = np.zeros((self.im_ht, self.im_wd, channel_size))

    temp_img = np.load(Image_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')

    for i in range(channel_size):
        img[:,:,i] = temp_img[self.count[index] + i]
        mask[:,:,i] = temp_label[self.count[index] + i]

    if self.transforms:
        img = np.uint8(img)
        mask = np.uint8(mask)
        img = self.transforms(img)
        mask = self.transforms(mask)

    return img, mask

В этом случае img и маска будут преобразованы отдельно, поскольку некоторые операции, такие как случайное вращение, являются случайными, поэтому соответствие между маской и изображением может быть изменено,Другими словами, изображение могло вращаться, но маска этого не делала.

РЕДАКТИРОВАТЬ 1

Я использовал метод в augmentations.py , но я получилошибка ::

Traceback (most recent call last):
  File "test_transform.py", line 87, in <module>
    for batch_idx, image, mask in enumerate(train_loader):
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 103, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/data.py", line 164, in __getitem__
    img, mask = self.transforms(img, mask)
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/augmentations.py", line 17, in __call__
    img, mask = a(img, mask)
TypeError: __call__() takes 2 positional arguments but 3 were given

Это мой код для __getitem__()

data_transforms = {
    'train': Compose([
        RandomHorizontallyFlip(),
        RandomRotate(degree=25),
        transforms.ToTensor()
    ]),
}

train_set = DatasetUnetForTestTransform(fold=args.fold, random_index=args.random_index,transforms=data_transforms['train'])

# __getitem__ in class DatasetUnetForTestTransform
def __getitem__(self, index):
    img = np.zeros((self.im_ht, self.im_wd, channel_size))
    mask = np.zeros((self.im_ht, self.im_wd, channel_size))
    temp_img = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_img, temp_label = crop_data_label_from_0(temp_img, temp_label)
    for i in range(channel_size):
        img[:,:,i] = temp_img[self.count[index] + i]
        mask[:,:,i] = temp_label[self.count[index] + i]

    if self.transforms:
        img = T.ToPILImage()(np.uint8(img))
        mask = T.ToPILImage()(np.uint8(mask))
        img, mask = self.transforms(img, mask)

    img = T.ToTensor()(img).copy()
    mask = T.ToTensor()(mask).copy()
    return img, mask

РЕДАКТИРОВАТЬ 2

Я обнаружил, что после ToTensor кости между одинаковыми меткамистановится 255 вместо 1, как это исправить?

# Dice computation
def DSC_computation(label, pred):
    pred_sum = pred.sum()
    label_sum = label.sum()
    inter_sum = np.logical_and(pred, label).sum()
    return 2 * float(inter_sum) / (pred_sum + label_sum)

Не стесняйтесь спрашивать, нужен ли еще код для объяснения проблемы.

Ответы [ 2 ]

2 голосов
/ 03 октября 2019

Преобразования, для которых требуются входные параметры, такие как RandomCrop, имеют метод get_param, который возвращает параметры для этого конкретного преобразования. Затем это может быть применено как к изображению, так и к маске, используя функциональный интерфейс преобразований:

from torchvision import transforms
import torchvision.transforms.functional as F

i, j, h, w = transforms.RandomCrop.get_params(input, (100, 100))
input = F.crop(input, i, j, h, w)
target = F.crop(target, i, j, h, w)

Образец, доступный здесь: https://github.com/pytorch/vision/releases/tag/v0.2.0

Полный пример, доступный здесь для VOC & COCO: https://github.com/pytorch/vision/blob/master/references/segmentation/transforms.py https://github.com/pytorch/vision/blob/master/references/segmentation/train.py

Что касается ошибки,

ToTensor() не был переопределен для обработки дополнительного аргумента маски, поэтому он не может быть в data_transforms. Более того, __getitem__ делает ToTensor из img и mask до их возвращения.

data_transforms = {
    'train': Compose([
        RandomHorizontallyFlip(),
        RandomRotate(degree=25),
        #transforms.ToTensor()  => remove this line
    ]),
}
1 голос
/ 03 октября 2019

torchvision также предоставляет аналогичные функции [документ] .

Вот простой пример:

import torchvision
from torchvision import transforms

trans = transforms.Compose([transforms.CenterCrop((178, 178)),
                                    transforms.Resize(128),
                                    transforms.RandomRotation(20),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dset = torchvision.datasets.MNIST(data_root, transforms=trans)

РЕДАКТИРОВАТЬ

Краткий пример настройки собственного набора данных CelebA. Обратите внимание, что для применения преобразований вам нужно вызвать список transform в __getitem__.

class CelebADataset(Dataset):
    def __init__(self, root, transforms=None, num=None):
        super(CelebADataset, self).__init__()

        self.img_root = os.path.join(root, 'img_align_celeba')
        self.attr_root = os.path.join(root, 'Anno/list_attr_celeba.txt')
        self.transforms = transforms

        df = pd.read_csv(self.attr_root, sep='\s+', header=1, index_col=0)
        #print(df.columns.tolist())
        if num is None:
            self.labels = df.values
            self.img_name = df.index.values
        else:
            self.labels = df.values[:num]
            self.img_name = df.index.values[:num]

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_root, self.img_name[index]))
        # only use blond_hair, eyeglass, male, smile
        indices = [9, 15, 20, 31]
        label = np.take(self.labels[index], indices)
        label[label==-1] = 0

        if self.transforms is not None:
            img = self.transforms(img)

        return np.asarray(img), label

    def __len__(self):
        return len(self.labels)


РЕДАКТИРОВАТЬ 2

Я, наверное, что-то упустил с первого взгляда. Суть вашей проблемы в том, как применить «одинаковую» предварительную обработку данных к img и меткам. Насколько я понимаю, нет доступной встроенной функции Pytorch. Итак, что я делал раньше, так это реализовывал расширение самостоятельно.

class RandomRotate(object):
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, img, mask):
        rotate_degree = random.random() * 2 * self.degree - self.degree
        return img.rotate(rotate_degree, Image.BILINEAR), 
                           mask.rotate(rotate_degree, Image.NEAREST)

Обратите внимание, что ввод должен быть в формате PIL. См. this для получения дополнительной информации.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...