При заданном параметре train_frac=0.8
эта функция случайным образом разделит dataset
на обучающую, валидационную и тестовую части в соотношении 80%, 10%, 10%:
import itertools

import torch
from torch.utils.data import TensorDataset, random_split
def dataset_split(dataset, train_frac):
    """Randomly split a dataset into train, validation and test subsets.

    The training subset receives ``train_frac`` of the samples and the
    remainder is divided evenly between validation and test:

        'train': train_frac
        'valid': (1 - train_frac) / 2
        'test':  (1 - train_frac) / 2

    param dataset: Dataset object to be split (must support ``len``)
    param train_frac: Ratio of train set to whole dataset, within [0, 1]
    return: dict with keys 'train', 'valid' and 'test' mapping to the
        subsets produced by ``torch.utils.data.random_split``
    raises ValueError: if ``train_frac`` is outside [0, 1]
    """
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not 0 <= train_frac <= 1:
        raise ValueError("Invalid training set fraction")

    length = len(dataset)
    # Floor the train and valid sizes; test takes whatever is left, so the
    # three lengths always sum to len(dataset) and rounding leftovers go to
    # the smaller held-out sets rather than to the training set.
    train_length = int(length * train_frac)
    valid_length = (length - train_length) // 2
    test_length = length - train_length - valid_length

    subsets = random_split(dataset, (train_length, valid_length, test_length))
    return dict(zip(('train', 'valid', 'test'), subsets))