Поскольку вам нужно тренироваться в одном и том же пакете несколько итераций, следующий скелет кода должен работать для вас.
def train(args, data_loader):
for idx, ex in enumerate(data_loader):
# iterate over each mini-batches
# add your code
def validate(args, data_loader):
with torch.no_grad():
for idx, ex in enumerate(data_loader):
# iterate over each mini-batches
# add your code
# args = dict() containing required parameters
for epoch in range(start_epoch, args.num_epochs):
# train_loader = data loader for the training data
train(args, train_loader)
Вы можете использовать загрузчик данных следующим образом.
class ReaderDataset(Dataset):
def __init__(self, examples):
# examples = a list of examples
# add your code
def __len__(self):
# return total dataset size
def __getitem__(self, index):
# write your code to return each batch item
train_dataset = ReaderDataset(train_examples)
train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
sampler=train_sampler,
num_workers=args.data_workers,
collate_fn=batchify,
pin_memory=args.cuda,
drop_last=args.parallel
)
# batchify is a custom function to prepare the mini-batches