Как разделить набор данных на настраиваемый обучающий набор и настраиваемый набор проверки с помощью pytorch? - PullRequest
1 голос
/ 06 мая 2020

Я использую набор данных, не относящийся к torchvision, и извлек его с помощью метода ImageFolder. Я пытаюсь разбить набор данных на набор проверки 20% и набор обучения 80%. Я могу найти только этот метод (random_split) из библиотеки PyTorch, который позволяет разделить набор данных. Однако каждый раз это происходит случайно. Мне интересно, есть ли способ разделить набор данных с указанным c количеством в библиотеке PyTorch?

Это мой код для извлечения набора данных и его случайного разделения.

transformations = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

TrafficSignSet = datasets.ImageFolder(root='./train/', transform=transformations)

####### split data
train_size = int(0.8 * len(TrafficSignSet))
test_size = len(TrafficSignSet) - train_size
train_dataset_split, test_dataset_split = torch.utils.data.random_split(TrafficSignSet, [train_size, test_size])

#######put into a Dataloader
train_dataset = torch.utils.data.DataLoader(train_dataset_split, batch_size=32, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test_dataset_split, batch_size=32, shuffle=True)

1 Ответ

1 голос
/ 06 мая 2020

Если вы посмотрите «под капот» random_split, вы увидите, что он использует torch.utils.data.Subset для фактического разделения. Вы можете сделать это самостоятельно с фиксированными индексами:

import random

indices = list(range(len(TrafficSignSet))
random.seed(310)  # fix the seed so the shuffle will be the same everytime
random.shuffle(indices)
train_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[:train_size])
val_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[train_size:])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...