Добавление пользовательских меток в pytorch dataloader / dataset не работает для пользовательского набора данных - PullRequest
0 голосов
/ 26 июня 2019

Я работаю над конкурсом изображений кактусов на Kaggle и пытаюсь использовать загрузчик данных PyTorch для моей CNN.Однако я сталкиваюсь с проблемой, когда я не могу установить метки для учебного набора.Изображения обучающего набора приведены в папке, а метки - в файле CSV.Это мой код.

 train = torchvision.datasets.ImageFolder(root='../input/train', 
 transform=transform)

 train.targets = torch.from_numpy(df['has_cactus'].values)

 train_loader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True, num_workers=2)

 for i, data in enumerate(train_loader, 0):
     print(data[1])

Этот код выводит пакетные тензоры всех нулей, что явно неверно, поскольку подавляющее большинство меток (если вы посмотрите на фрейм данных) являются единицами.Я считаю, что это проблема с назначением меток для «train.targets».Если «train.targets» напечатан до назначения других меток, он возвращает тензор всех нулей, что согласуется с неверными результатами, которые я получаю.Как мне исправить эту проблему?

Ответы [ 2 ]

1 голос
/ 26 июня 2019

Я обычно наследую встроенный класс DataSet следующим образом:

from torch.utils.data import DataLoader
class DataSet:

    def __init__(self, root):
        """Init function should not do any heavy lifting, but
            must initialize how many items are available in this data set.
        """

        self.ROOT = root
        self.images = read_images(root + "/images")
        self.labels = read_labels(root + "/labels")

    def __len__(self):
        """return number of points in our dataset"""

        return len(self.images)

    def __getitem__(self, idx):
        """ Here we have to return the item requested by `idx`
            The PyTorch DataLoader class will use this method to make an iterable for
            our training or validation loop.
        """

        img = images[idx]
        label = labels[idx]

        return img, label

И теперь вы можете создать экземпляр этого класса как

ds = Dataset('../input/train')

Теперь вы можете создать экземплярDataLoader:

dl = DataLoader(ds, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True)

Это создаст пакеты ваших данных, к которым вы можете получить доступ как:

for image, label in dl:
    print(label)
0 голосов
/ 27 июня 2019

Вы можете создать пользовательский загрузчик набора данных, унаследовав встроенный класс Dataset, как упомянул @Sai Krishnan.

from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from PIL import Image

VOC_CLASSES = ('background',  # always index 0
               'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair',
               'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant',
               'sheep', 'sofa', 'train', 'tvmonitor')

NUM_CLASSES = len(VOC_CLASSES) + 1

class customDataset(Dataset):
    """Pascal VOC 2007 Dataset"""
    def __init__(self, list_file, img_dir, mask_dir, transform=None):
        # list of images to load in a .txt file
        self.images = open(list_file, "rt").read().split("\n")[:-1]
        self.transform = transform
        # note that in the .txt file the image names are stored without the extension(.jpg or .png)
        self.img_extension = ".jpg"
        self.mask_extension = ".png"

        self.image_root_dir = img_dir
        self.mask_root_dir = mask_dir
        # can comment the line below
        self.counts = self.__compute_class_probability()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        name = self.images[index]
        image_path = os.path.join(self.image_root_dir, name + self.img_extension)
        mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)

        image = self.load_image(path=image_path)
        gt_mask = self.load_mask(path=mask_path)
        data = {
                    'image': torch.FloatTensor(image),
                    'mask' : torch.LongTensor(gt_mask)
                    }
        return data

    def __compute_class_probability(self):
        counts = dict((i, 0) for i in range(NUM_CLASSES))

        for name in self.images:
            mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)

            raw_image = Image.open(mask_path).resize((224, 224))
            imx_t = np.array(raw_image).reshape(224*224)
            imx_t[imx_t==255] = len(VOC_CLASSES)

            for i in range(NUM_CLASSES):
                counts[i] += np.sum(imx_t == i)
        return counts

    def get_class_probability(self):
        values = np.array(list(self.counts.values()))
        p_values = values/np.sum(values)
        return torch.Tensor(p_values)

    def load_image(self, path=None):
        # can use any other library too like OpenCV as long as you are consistent with it
        raw_image = Image.open(path)
        raw_image = np.transpose(raw_image.resize((224, 224)), (2,1,0))
        imx_t = np.array(raw_image, dtype=np.float32)/255.0

        return imx_t
    # can comment the below function if not needed
    def load_mask(self, path=None):
        raw_image = Image.open(path)
        raw_image = raw_image.resize((224, 224))
        imx_t = np.array(raw_image)
        imx_t[imx_t==255] = len(VOC_CLASSES)
        return imx_t

Когда класс готов, вы можете создать его экземпляр и использовать его.

data_root = os.path.join("VOCdevkit", "VOC2007")
list_file_path = os.path.join(data_root, "ImageSets", "Segmentation", "train.txt")
img_dir = os.path.join(data_root, "JPEGImages")
mask_dir = os.path.join(data_root, "SegmentationClass")


objects_dataset = customDataset(list_file=list_file_path,
                                        img_dir=img_dir,
                                        mask_dir=mask_dir)
sample = objects_dataset[k]
image, mask = sample['image'], sample['mask']
image.transpose_(0, 2)

fig = plt.figure()

a = fig.add_subplot(1,2,1)
plt.imshow(image)

a = fig.add_subplot(1,2,2)
plt.imshow(mask)

plt.show()

Убедитесь, что вы правильно вставили пути к файлам.Также вам нужно будет правильно загрузить метки в классе customDataset().

Примечание. Этот фрагмент является лишь примером того, каким должен быть пользовательский загрузчик данных.Вам нужно будет внести в него соответствующие изменения, чтобы он работал в вашем случае.

...