Как увеличить размер пакета в pytorch? - PullRequest
1 голос
/ 22 сентября 2019

В pytorch DataLoader разделяет набор данных на пакеты заданного размера с дополнительными опциями перестановки и т. Д., Которые затем можно зациклить.

Но если мне нужно увеличить размер пакета, например, первые 10партия размера 50, следующие 5 партий размера 100 и т. д., как лучше всего это сделать?

Я попытался разбить тензор, затем объединить их:

#10x50 + 5*100
originalTensor = torch.randn(1000, 80)
split1=torch.split(originalTensor, 500, dim=0)
split2=torch.split(list(split1)[0], 100, dim=0)

После этого естьспособ передать составной тензор в dataLoader или любой другой способ напрямую превратить объединенный тензор в генератор (который может потерять тасование и другие функции)?

1 Ответ

1 голос
/ 22 сентября 2019

Я думаю, что вы можете сделать это, просто указав batch_sampler для вашего DataLoader.
Например:

class VaryingSizeBatchSampler(Sampler):
    r"""Wraps another sampler to yield a varying-size mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size_fn (function): Size of current mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
    """

    def __init__(self, sampler, batch_size_fn, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        self.sampler = sampler
        self.batch_size_fn = batch_size_fn
        self.drop_last = drop_last
        self.batch_counter = 0

    def __iter__(self):
        batch = []
        cur_batch_size = self.batch_size_fn(self.batch_counter)  # get current batch size
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == cur_batch_size:
                yield batch
                self.batch_counter += 1
                cur_batch_size = self.batch_size_fn(self.batch_counter)  # get current batch size                
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch            

    def __len__(self):
        raise NotImplementedError('You need to implement it yourself!')
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...