Как я могу сохранить экземпляр DataLoader PyTorch? - PullRequest
1 голос
/ 02 апреля 2020

Я хочу сохранить экземпляр PyTorch torch.utils.data.dataloader.DataLoader, чтобы я мог продолжить обучение с того места, где остановился (сохраняя случайное начало, состояния и все остальное).

1 Ответ

0 голосов
/ 04 апреля 2020

Это довольно просто. Нужно создать свой собственный Sampler, который берет начальный индекс и самостоятельно перетасовывает данные:

import random
from torch.utils.data.dataloader import Sampler


random.seed(224)  # use a fixed number


class MySampler(Sampler):
    def __init__(self, data, i=0):
        self.seq = list(range(len(data)))[i * batch_size:]
        random.shuffle(self.seq)

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

Теперь сохраните последний индекс где-нибудь i и в следующий раз создайте экземпляр DataLoader, используя его :

train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, last_i)
train_data_loader = DataLoader(dataset=train_dataset,                                                         
                               batch_size=batch_size, 
                               sampler=train_sampler,
                               shuffle=False)  # don't forget to set DataLoader's shuffle to False

Это очень полезно при тренировках на Colab.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...