Python, набор данных класса, как объединить изображения с соответствующими метками в pytorch - PullRequest
1 голос
/ 17 июня 2020

Я новичок в PyTorch, и последние пару дней я боролся с классом Dataset, который позволяет вам создавать свой собственный набор данных.

Я работаю с этим набором данных (https://www.kaggle.com/ianmoone0617/flower-goggle-tpu-classification/kernels), проблема в том, что изображения и их метки находятся в отдельных папках, и я не могу понять, как их объединить.

Это код, который я использую:

class MyDataset(Dataset):

    def __init__(self, csv_file, root_dir, transform=None):
        self.labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        image_name = os.path.join(self.root_dir, self.labels.iloc[index, 0])
        image = io.imread(image_name)

        if self.transform:
            image = self.transform(image)

        return (image, labels)

В то время как структура папок следующая: structure of the folders]

Я действительно хочу понять это, поэтому заранее спасибо, ребята !!

1 Ответ

0 голосов
/ 17 июня 2020

Похоже, вы почти у цели. Есть много способов справиться с этим. Например, вы можете прочитать оба файла csv во время инициализации, чтобы создать словарь, который сопоставляет строку метки в flowers_idx.csv с индексом метки, указанным в flowers_label.csv.

import os
import pandas as pd
import torch
from torchvision.datasets.folder import default_loader
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data_csv, label_csv, root_dir, transform=None):
        self.data_entries = pd.read_csv(data_csv)
        self.root_dir = root_dir
        self.transform = transform

        label_map = pd.read_csv(label_csv)
        self.label_str_to_idx = {label_str: label_idx for label_idx, label_str in label_map.iloc}

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.item()

        label = self.label_str_to_idx[self.data_entries.iloc[index, 1]] 
        image_path = os.path.join(self.root_dir, f'{self.data_entries.iloc[index, 0]}.jpeg')

        # torchvision datasets generally return PIL image rather than numpy ndarray
        image = default_loader(image_path)

        # alternative to load ndarray using skimage.io
        # image = io.imread(image_path)

        if self.transform:
            image = self.transform(image)

        return (image, label)

Обратите внимание, что это возвращает PIL изображения, а не ndarrays, так как это обычно то, что обычно возвращает наборы данных torchvision. Это также хорошо, так как многие преобразования torchvision могут быть применены только к изображениям PIL.

На данный момент простым вариантом использования может быть:

import torchvision.transforms as tt

dataset_dir = '/home/jodag/datasets/527293_966816_bundle_archive'
# TODO add more transforms/data-augmentation etc...
transform = tt.Compose((
    tt.ToTensor(),
))

dataset = MyDataset(
    os.path.join(dataset_dir, 'flowers_idx.csv'),
    os.path.join(dataset_dir, 'flowers_label.csv'),
    os.path.join(dataset_dir, 'flower_tpu/flower_tpu/flowers_google/flowers_google'),
    transform)

image, label = dataset[0]

Во время обучения или проверки вы, вероятно, будете использовать a DataLoader для выборки набора данных.

...