Перебор подмножеств из torch.utils.data.random_split - PullRequest
2 голосов
/ 10 февраля 2020

Я сейчас загружаю папку с данными обучения AI. Подпапки представляют названия меток с соответствующими изображениями внутри. Это хорошо работает при использовании загрузчика ImageFolder pyTorch.

def load_dataset():
    data_path = 'C:/example_folder/'

    train_dataset_manual = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )

    train_loader_manual = torch.utils.data.DataLoader(
        train_dataset_manual,
        batch_size=1,
        num_workers=0,
        shuffle=True
    )

    return train_loader_manual

full_dataset = load_dataset()

Теперь я хочу разделить этот набор данных на обучающие и тестовые данные. Для этого я использую функцию random_split:

training_data_size = 0.8

train_size = int(training_data_size * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

Набор данных full_datase является типом torch.utils.data.dataloader.DataLoader. Я могу перебрать его с помощью al oop следующим образом:

for batch_idx, (data, target) in enumerate(full_dataset):
    print(batch_idx)

train_dataset - это объект типа torch.utils.data.dataset.Subset. Если я попытаюсь l oop через него, я получу:

TypeError 'DataLoader' объект не может быть подписан:

for batch_idx, (data, target) in enumerate(train_dataset):
    print(batch_idx)

Как я могу l oop через это? Я относительно новичок в Python.

Спасибо!

1 Ответ

4 голосов
/ 10 февраля 2020

Вам необходимо применить random_split к Dataset, а не к DataLoader. Набор данных, используемый для определения DataLoader, доступен в элементе DataLoader.dataset.

Например, вы можете сделать

train_dataset, test_dataset = torch.utils.data.random_split(full_dataset.dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)

Затем вы можете перебрать train_loader и test_loader как и ожидалось.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...