Pytorch - Custom DataLoader работает вечно - PullRequest
0 голосов
/ 16 апреля 2019
class TripletImageLoader(torch.utils.data.Dataset):
    def __init__(self):
        self.data = [0]*10000000

    def __getitem__(self, index):
        pid = os.getpid() % WORKER_SIZE
        # My code here only uses pid, doesnt use index

        return torch.tensor(batch.data), torch.tensor(batch.label)

    def __len__(self):
        return len(self.data)

Мне нужен мой загрузчик данных, чтобы работать вечно.Прямо сейчас он всегда завершается после нажатия 10000000 или любого максимального целочисленного размера.Как сделать так, чтобы этот прогон длился вечно, меня не волнует индекс, я им не пользуюсь.Я просто использую рабочие возможности этого класса

1 Ответ

0 голосов
/ 16 апреля 2019

Поскольку вам нужно тренироваться в одном и том же пакете несколько итераций, следующий скелет кода должен работать для вас.

def train(args, data_loader):
    for idx, ex in enumerate(data_loader):
        # iterate over each mini-batches
        # add your code

def validate(args, data_loader):
     with torch.no_grad():
        for idx, ex in enumerate(data_loader):
            # iterate over each mini-batches
            # add your code

# args = dict() containing required parameters
for epoch in range(start_epoch, args.num_epochs):
    # train_loader = data loader for the training data
    train(args, train_loader)

Вы можете использовать загрузчик данных следующим образом.

class ReaderDataset(Dataset):
    def __init__(self, examples):
        # examples = a list of examples
        # add your code

    def __len__(self):
        # return total dataset size

    def __getitem__(self, index):
        # write your code to return each batch item

train_dataset = ReaderDataset(train_examples)
train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.data_workers,
        collate_fn=batchify,
        pin_memory=args.cuda,
        drop_last=args.parallel
    )
# batchify is a custom function to prepare the mini-batches
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...