Корректная загрузка, разбиение и увеличение данных в Pytorch - PullRequest
1 голос
/ 13 июня 2019

Учебное пособие, похоже, не объясняет, как мы должны загружать, разбивать и делать правильное увеличение.

Давайте получим набор данных, состоящий из автомобилей и кошек.Структура папок будет такой:

data
  cat
    0101.jpg
    0201.jpg
    ...
  dogs
    0101.jpg
    0201.jpg
    ...

Сначала я загрузил набор данных с помощью наборов данных. Функция ImageFolder.Функция Image имеет команду «TRANSFORM», где мы можем установить некоторые команды дополнения, но мы не хотим применять дополнения к тестовому набору данных!Итак, давайте останемся с transform = None.

data = datasets.ImageFolder(root='data')

По-видимому, у нас нет обучения и тестирования структуры папок, и поэтому я предполагаю, что хорошим подходом было бы использование функции split_dataset

    train_size = int(split * len(data))
    test_size = len(data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])

Сейчасдавайте загрузим данные следующим образом.

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=8,
                                              shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=8,
                                              shuffle=True)

Как применить преобразования (увеличение данных) к изображениям "train_loader"?

В основном мне нужно: 1. загрузить данные из структуры папок, описанной выше 2. разбить данные на части для испытаний / обучения 3. применить аугментации для части поезда.

1 Ответ

0 голосов
/ 13 июня 2019

Я не уверен, есть ли рекомендуемый способ сделать это, но именно так я бы обошел эту проблему:

Учитывая, что torch.utils.data.random_split() возвращает Subset, мы не можем ( можетмы не уверены на 100% Я дважды проверил, мы не можем) использовать их внутренние наборы данных, потому что они одинаковы (единственное различие заключается в индексах).В этом контексте я бы реализовал простой класс для применения преобразований, что-то вроде этого:

from torch.utils.data import Dataset

class ApplyTransform(Dataset):
    """
    Apply transformations to a Dataset

    Arguments:
        dataset (Dataset): A Dataset that returns (sample, target)
        transform (callable, optional): A function/transform to be applied on the sample
        target_transform (callable, optional): A function/transform to be applied on the target

    """
    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform
        # yes, you don't need these 2 lines below :(
        if transform is None and target_transform is None:
            print("Am I a joke to you? :)")

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.dataset)

И затем использовал его перед передачей набора данных в загрузчик данных:

import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.ToTensor(),
    # ...
])
train_dataset = ApplyTransform(train_dataset, transform=train_transform)

# continue with DataLoaders...
...