x_train, y_train = torch.rand((708, 256, 3)), torch.rand((708, 4)) # data
class training_set(data.Dataset):
def __init__(self,X,Y):
self.X = X # set data
self.Y = Y # set lables
def __len__(self):
return len(self.X) # return length
def __getitem__(self, idx):
return [self.X[idx], self.Y[idx]] # return list of batch data [data, labels]
training_dataset = training_set(x_train, y_train)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=50, shuffle=True)
На самом деле вам не нужно использовать пользовательский набор данных, потому что в вашем случае это простой набор данных. Сначала вы можете изменить на TensorDataset
, чтобы использовать
training_dataset = torch.utils.data.TensorDataset(x_train, y_train)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=50, shuffle=True)
, чтобы оба возвращали одинаковые результаты.