Пакеты точек с одинаковым ярлыком на Pytorch - PullRequest
2 голосов
/ 17 марта 2020

Я хочу тренировать нейронную сеть, используя градиентный спуск в партиях, которые содержат N тренировочных точек каждая Я хотел бы, чтобы эти партии содержали только точки с одинаковой меткой, а не отбирались случайным образом из обучающего набора.

Например, если я тренируюсь с использованием MNIST, я хотел бы, чтобы партии выглядели как следующее:

batch_1 = {0,0,0,0,0,0,0,0}

batch_2 = {3,3,3,3,3,3,3,3}

batch_3 = {7,7,7,7,7,7,7,7}

.....

и так далее.

Как я могу сделать это, используя pytorch?

1 Ответ

3 голосов
/ 18 марта 2020

Один из способов сделать это - создать подмножества и загрузчики данных для каждого класса, а затем выполнить итерацию путем случайного переключения между загрузчиками данных на каждой итерации:

import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
import numpy as np

dataset = MNIST('path/to/mnist_root/', 
                transform=transforms.ToTensor(),
                download=True)

class_inds = [torch.where(dataset.targets == class_idx)[0]
              for class_idx in dataset.class_to_idx.values()]

dataloaders = [
    DataLoader(
        dataset=Subset(dataset, inds),
        batch_size=8,
        shuffle=True,
        drop_last=False)
    for inds in class_inds]

epochs = 1

for epoch in range(epochs):
    iterators = list(map(iter, dataloaders))   
    while iterators:         
        iterator = np.random.choice(iterators)
        try:
            images, labels = next(iterator)   
            print(labels)
            # do_more_stuff()

        except StopIteration:
            iterators.remove(iterator)

Это будет работать с любым набором данных (не только MNIST). ). Вот результат печати этикеток на каждой итерации:

tensor([6, 6, 6, 6, 6, 6, 6, 6])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([5, 5, 5, 5, 5, 5, 5, 5])
tensor([8, 8, 8, 8, 8, 8, 8, 8])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
...
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1])

Обратите внимание, что при установке drop_last=False, здесь и там будут пакеты, содержащие менее batch_size элементов. При установке значения True все партии будут одинакового размера, но некоторые точки данных будут отброшены.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...