Мы можем использовать для этого модуль torch.utils.data
, выполнив следующие шаги:
Создание класса набора данных для загрузки пользовательских данных путем наследования torch.utils.data.Dataset
Создание объекта набора данных путем передачи данных в экземпляр пользовательского класса набора данных
Используйте torch.utils.data.DataLoader
для загрузки набора данных и получения пакетов
при условии, что вы загрузили данные из каталога, в обучающие и тестовые массивы numpy, вы можете наследовать от torch.utils.data.Dataset
класса, чтобы создать объект набора данных
class MyDataset(Dataset):
def __init__(self, x, y):
super(MyDataset, self).__init__()
assert x.shape[0] == y.shape[0] # assuming shape[0] = dataset size
self.x = x
self.y = y
def __len__(self):
return self.y.shape[0]
def __getitem__(self, index):
return self.x[index], self.y[index]
Затем создайте свой объект набора данных
traindata = MyDataset(train_x, train_y)
Наконец, используйте DataLoader
для создания мини-пакетов
trainloader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)