Я хочу создать несколько подмножеств набора данных MNIST, представленных в Pytorch. Каждое подмножество должно иметь разные классы. Я попробовал следующее:
def split_MNIST(mnist_set, digits):
dset = mnist_set
classes = []
indices = dset.targets == digits[0]
classes.append(dset.classes[digits[0]])
if len(digits) > 1:
for digit in digits[1:]:
idx = dset.targets == digit
indices = indices + idx
classes.append(dset.classes[digit])
dset.targets = dset.targets[indices]
dset.data = dset.data[indices]
dset.classes = classes
return dset
train = datasets.MNIST("../data", train=True, download=True,
transform=transforms.Compose([transforms.ToTensor()]))
test =datasets.MNIST("../data", train=False, download=True,
transform=transforms.Compose([transforms.ToTensor()]))
tr = split_MNIST(train, [1,2,3])
trainset = torch.utils.data.DataLoader(tr, batch_size=16, shuffle=True)
Это работает, но вместо создания нового набора данных он фактически меняет исходную переменную поезда. Есть ли способ создать клон набора данных вместо того, чтобы сохранить исходный?