Один из способов сделать это - создать подмножества и загрузчики данных для каждого класса, а затем выполнить итерацию путем случайного переключения между загрузчиками данных на каждой итерации:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
import numpy as np
dataset = MNIST('path/to/mnist_root/',
transform=transforms.ToTensor(),
download=True)
class_inds = [torch.where(dataset.targets == class_idx)[0]
for class_idx in dataset.class_to_idx.values()]
dataloaders = [
DataLoader(
dataset=Subset(dataset, inds),
batch_size=8,
shuffle=True,
drop_last=False)
for inds in class_inds]
epochs = 1
for epoch in range(epochs):
iterators = list(map(iter, dataloaders))
while iterators:
iterator = np.random.choice(iterators)
try:
images, labels = next(iterator)
print(labels)
# do_more_stuff()
except StopIteration:
iterators.remove(iterator)
Это будет работать с любым набором данных (не только MNIST). ). Вот результат печати этикеток на каждой итерации:
tensor([6, 6, 6, 6, 6, 6, 6, 6])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([5, 5, 5, 5, 5, 5, 5, 5])
tensor([8, 8, 8, 8, 8, 8, 8, 8])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
...
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1])
Обратите внимание, что при установке drop_last=False
, здесь и там будут пакеты, содержащие менее batch_size
элементов. При установке значения True все партии будут одинакового размера, но некоторые точки данных будут отброшены.