Как создать итератор балансирующего цикла в PyTourch? - PullRequest
0 голосов
/ 13 июля 2020

Допустим, у меня 2 класса. И для одного у меня только 17 образцов, а для другого 83. Я хочу всегда иметь равное количество данных из каждого класса за эпоху (в данном случае это означает 17 на 17). Кроме того, я хочу сдвинуть окно выборки по классу, где у меня есть больше данных для каждой эпохи (первые 17, следующие 17, ...).

В настоящее время у меня есть итератор циклической выборки, подобный этому:

class CyclicIterator:
    def __init__(self, loader, sampler):
        self.loader = loader
        self.sampler = sampler
        self.epoch = 0
        self._next_epoch()

    def _next_epoch(self):
        self.iterator = iter(self.loader)
        self.epoch += 1

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            self._next_epoch()
            return next(self.iterator)

Интересно, как заставить все сэмплы из каждого класса иметь равное количество за эпоху?

1 Ответ

1 голос
/ 15 июля 2020

Для сбалансированной партии, что означает равное (или близкое к равному) количество выборок на категорию в каждой партии, есть несколько подходов: образцы). В этом подходе вы можете использовать следующий код:

https://github.com/galatolofederico/pytorch-balanced-batch

-Undersampling (предоставляет количество выборок для всех категорий на основе наименьшего номера категории). По моему опыту, функция ниже делает то же самое с использованием библиотеки PyTorch:

torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

Где веса - это вероятность каждого образца, это зависит от того, сколько образцов в каждой категории у вас есть, например, если ваши данные простые так как эти данные = [0, 1, 0, 0, 1], счетчик класса «0» равен 3, а счетчик класса «1» равен 2 Таким образом, вектор весов равен [1/3, 1/2, 1/3, 1 / 3, 1/2]. С этим вы можете вызвать WeightedRamdomSampler, и он сделает это за вас. Вам нужно вызвать его в Dataloader. Код для его настройки:

sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
train_dataloader = DataLoader(dataset_train, batch_size=mini_batch,
                              sampler=sampler, shuffle=False,
                              num_workers=1)
...