Это довольно просто. Нужно создать свой собственный Sampler
, который берет начальный индекс и самостоятельно перетасовывает данные:
import random
from torch.utils.data.dataloader import Sampler
random.seed(224) # use a fixed number
class MySampler(Sampler):
def __init__(self, data, i=0):
self.seq = list(range(len(data)))[i * batch_size:]
random.shuffle(self.seq)
def __iter__(self):
return iter(self.seq)
def __len__(self):
return len(self.seq)
Теперь сохраните последний индекс где-нибудь i
и в следующий раз создайте экземпляр DataLoader
, используя его :
train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, last_i)
train_data_loader = DataLoader(dataset=train_dataset,
batch_size=batch_size,
sampler=train_sampler,
shuffle=False) # don't forget to set DataLoader's shuffle to False
Это очень полезно при тренировках на Colab.