Pytorch: как создать собственный набор данных с помощью справочной таблицы - PullRequest
0 голосов
/ 21 июня 2019

У меня есть файл reference.csv, который имеет три столбца: Тип, Класс и Путь. Вот первые 5 примеров строк:

"Type","Class","Path"
"train","A","./path1/001.jpg"
"train","A","./path2/002.jpg"
"test","C","./path3/003.jpg"
"train","B","./path4/001.jpg"
"test","B","./path5/002.jpg"
...

В более удобном для просмотра формате:

|----------------------|------------------|------------------|
|         Type         |       Class      |       Path       |
|----------------------|------------------|------------------|
|        train         |         A        | ./path1/001.jpg  |
|----------------------|------------------|------------------|
|        train         |         A        | ./path2/002.jpg  |
|----------------------|------------------|------------------|
|        train         |         C        | ./path3/003.jpg  |
|----------------------|------------------|------------------|
|        test          |         B        | ./path4/001.jpg  |
|----------------------|------------------|------------------|
|        test          |         B        | ./path5/002.jpg  |
|----------------------|------------------|------------------|

Я хочу создать класс набора данных (torch.utils.data.Dataset) для чтения изображений, чтобы я мог использовать DataLoader (torch.utils.data.DataLoader).

Как правильно создать пользовательский набор данных с использованием справочной таблицы?

1 Ответ

1 голос
/ 21 июня 2019

Если мы хотим создать собственный набор данных, который считывает местоположения изображений из этого файла CSV, тогда мы можем сделать что-то вроде следующего. Ваша логика может отличаться.

class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file
            img_path (string): path to the folder where images are
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()

        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)

        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])

        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])

        # Third column is for an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])

        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):

        # Get image name from the pandas df
        single_image_name = self.image_arr[index]

        # Open image
        img_as_img = Image.open(single_image_name)

        # Check if there is an operation
        some_operation = self.operation_arr[index]

        # If there is an operation
        if some_operation:
            # Do some operation on image
            # ...
            # ...
            pass

        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)

        # Get label(class) of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # Call dataset
    custom_mnist_from_images =  \
        CustomDatasetFromImages('../data/mnist_labels.csv')
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...