Как использовать Batchsampler в Dataloader - PullRequest
0 голосов
/ 27 апреля 2020

Мне нужно использовать BatchSampler в пределах pytorch DataLoader вместо многократного вызова __getitem__ набора данных (удаленный набор данных, каждый запрос дорогой).
Я не могу понять, как использовать пакетный сэмплер с любым данным набором данных.

например,

class MyDataset(Dataset):

    def __init__(self, remote_ddf, ):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, idx):
        return self.ddf[idx] --------> This is as expensive as a batch call

    def get_batch(self, batch_idx):
        return self.ddf[batch_idx]

my_loader = DataLoader(MyDataset(remote_ddf), 
           batch_sampler=BatchSampler(Sampler(), batch_size=3))

Вещь, которую я не понимаю, также не нашел ни одного примера в Интернете или в документации по факелам, как я могу использовать мою функцию get_batch вместо функции __getitem__.
Редактировать: Следуя ответу Шимона Маске, это то, что я пытался, и все же, \_\_get_item__ получает по одному индексу за каждый вызов вместо списка размером batch_size

class Dataset(Dataset):

    def __init__(self):
       ...

    def __len__(self):
        ...

    def __getitem__(self, batch_idx):  ------> here I get only one index
        return self.wiki_df.loc[batch_idx]


loader = DataLoader(
                dataset=dataset,
                batch_sampler=BatchSampler(
                    SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
                num_workers=self.hparams.num_data_workers,
            )

1 Ответ

1 голос
/ 27 апреля 2020

Вы не можете использовать get_batch вместо __getitem__, и я не вижу смысла делать это так.

torch.utils.data.BatchSampler берет индексы от вашего Sampler() экземпляр (в данном случае 3 из них) и возвращает его как list, чтобы их можно было использовать в вашем методе MyDataset __getitem__ (проверьте исходный код , большинство сэмплеров и данных связанные утилиты просты в использовании, если вам это нужно).

Я предполагаю, что ваша self.ddf поддерживает нарезку списка (например, self.ddf[[25, 44, 115]] возвращает значения правильно и использует только один дорогой вызов). В этом случае просто переключите get_batch на __getitem__, и вам будет нужно go.

class MyDataset(Dataset):

    def __init__(self, remote_ddf, ):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, batch_idx):
        return self.ddf[batch_idx] -> batch_idx is a list

РЕДАКТИРОВАТЬ: Вы должны указать batch_sampler как sampler, в противном случае партия будет разделена на отдельные индексы. Это должно быть хорошо:

loader = DataLoader(
    dataset=dataset,
    # This line below!
    sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
    ),
    num_workers=self.hparams.num_data_workers,
)
...