Похоже, вы почти у цели. Есть много способов справиться с этим. Например, вы можете прочитать оба файла csv во время инициализации, чтобы создать словарь, который сопоставляет строку метки в flowers_idx.csv
с индексом метки, указанным в flowers_label.csv
.
import os
import pandas as pd
import torch
from torchvision.datasets.folder import default_loader
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data_csv, label_csv, root_dir, transform=None):
self.data_entries = pd.read_csv(data_csv)
self.root_dir = root_dir
self.transform = transform
label_map = pd.read_csv(label_csv)
self.label_str_to_idx = {label_str: label_idx for label_idx, label_str in label_map.iloc}
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
if torch.is_tensor(index):
index = index.item()
label = self.label_str_to_idx[self.data_entries.iloc[index, 1]]
image_path = os.path.join(self.root_dir, f'{self.data_entries.iloc[index, 0]}.jpeg')
# torchvision datasets generally return PIL image rather than numpy ndarray
image = default_loader(image_path)
# alternative to load ndarray using skimage.io
# image = io.imread(image_path)
if self.transform:
image = self.transform(image)
return (image, label)
Обратите внимание, что это возвращает PIL
изображения, а не ndarrays, так как это обычно то, что обычно возвращает наборы данных torchvision. Это также хорошо, так как многие преобразования torchvision могут быть применены только к изображениям PIL.
На данный момент простым вариантом использования может быть:
import torchvision.transforms as tt
dataset_dir = '/home/jodag/datasets/527293_966816_bundle_archive'
# TODO add more transforms/data-augmentation etc...
transform = tt.Compose((
tt.ToTensor(),
))
dataset = MyDataset(
os.path.join(dataset_dir, 'flowers_idx.csv'),
os.path.join(dataset_dir, 'flowers_label.csv'),
os.path.join(dataset_dir, 'flower_tpu/flower_tpu/flowers_google/flowers_google'),
transform)
image, label = dataset[0]
Во время обучения или проверки вы, вероятно, будете использовать a DataLoader
для выборки набора данных.