ConcatDataset
- это пользовательский класс, который находится в подклассе от torch.utils.data.Dataset
. Давайте рассмотрим один пример.
class ConcatDataset(torch.utils.data.Dataset):
def __init__(self, *datasets):
self.datasets = datasets
def __getitem__(self, i):
return tuple(d[i] for d in self.datasets)
def __len__(self):
return min(len(d) for d in self.datasets)
train_loader = torch.utils.data.DataLoader(
ConcatDataset(
dataset1,
dataset2
),
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
pin_memory=True)
for i, (input, target) in enumerate(train_loader):
...
Здесь два набора данных, а именно dataset1
(список примеров) и dataset2
, объединены в единый обучающий набор данных. Функция __getitem__
возвращает один пример из набора данных и будет использоваться BatchSampler
для формирования обучающих мини-пакетов.
Будут ли возвращенные выборки партии случайным последовательным патроном изодин файл или будет пакетный интервал между несколькими случайными индексами во всех файлах данных?
Поскольку вы объединили все свои файлы данных в один набор данных, теперь это зависит от того, какой BatchSampler
вы используетепробовать мини-партии. В PyTorch реализовано несколько сэмплеров, например, RandomSampler
, SequentialSampler
, SubsetRandomSampler
, WeightedRandomSampler
. См. Их использование в документации .
Вы также можете настроить свой BatchSampler
следующим образом.
class MyBatchSampler(Sampler):
def __init__(self, *params):
# write your code here
def __iter__(self):
# write your code here
# return an iterable
def __len__(self):
# return the size of the dataset
Функция __iter__
должна возвращать итерируемое значениемини-партии. Вы можете реализовать свою логику формирования мини-пакетов в этой функции.
Для случайной выборки отрицательных примеров для обучения одной из альтернатив может быть выбор отрицательных примеров для каждого положительного примера в функции __init__
функции ConcatDataset
класс.