Как PyTorch обрабатывает метки при загрузке файлов изображений / масок для сегментации изображений? - PullRequest
0 голосов
/ 05 февраля 2019

Я запускаю проект сегментации изображений с помощью PyTorch.У меня есть уменьшенный набор данных в папке и 2 подпапки - «изображение» для хранения изображений и «маска» для маскированных изображений.Изображения и маски представляют собой файлы .png с 3 каналами и 256x256 пикселей.Поскольку это сегментация изображения, маркировка должна выполняться попиксельно.Я работаю только с 2 классами на данный момент для простоты.До сих пор я достиг следующего:

Я смог загрузить свои файлы в классы "images" или "mask" на

root_dir="./images_masks"
train_ds_untransf = torchvision.datasets.ImageFolder(root=root_dir)
train_ds_untransf.classes
Out[621]:
['images', 'masks']  

и преобразовать данные в тензоры

from torchvision import transforms
train_trans = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.ImageFolder(root=root_dir,transform=train_trans)

Каждый тензор в этом "train_dataset" имеет следующую форму:

train_dataset[1][0].shape
torch.Size([3, 256, 256])

Теперь мне нужно передать загруженные данные в модель CNN и изучить PyTorch DataLoader для этого

train_dataloaded = DataLoader(train_dataset, batch_size=2, shuffle=False, num_workers=4)

Я использую следующий код для проверки результирующей формы тензора

for x, y in train_dl:
    print (x.shape)
    print (y.shape)
    print(y)

и получаю

torch.Size([2, 3, 256, 256])
torch.Size([2])
tensor([0, 0])
torch.Size([2, 3, 256, 256])
torch.Size([2])
tensor([0, 1])
.
.
.

Фигуры кажутся правильными.Тем не менее, первая проблема заключается в том, что я получил тензоры из одной и той же папки, обозначенные некоторыми тензорами "y" с одинаковым значением [0, 0].Я ожидаю, что все они представляют собой [1, 0]: 1 представляет изображение, 0 представляет маски.

Вторая проблема заключается в том, что, хотя документация ясна, когда метки представляют собой целые изображения, она не ясна в отношениикак применять его для надписей на уровне пикселей, и я уверен, что надписи не верны.

Что было бы альтернативой для правильной маркировки этого набора данных?

спасибо

1 Ответ

0 голосов
/ 05 февраля 2019

Класс torchvision.datasets.ImageFolder предназначен для классификации изображений задач, а не для сегментации;поэтому он ожидает одну целочисленную метку на изображение, и метка определяется подпапкой, в которой хранятся изображения.Итак, что касается вашего загрузчика данных, у вас есть два класса изображений: «изображения» и «маски», и ваша сеть пытается различить их.

На самом деле вам нужна другая реализациянабора данных, который для каждого __getitem__ возвращает изображение и соответствующую маску.Вы можете увидеть примеры таких классов здесь .

Кроме того, немного странно, что ваши двоичные метки по пикселям хранятся как 3-канальное изображение.Маски сегментации обычно хранятся как одноканальное изображение.

...