почему мой сэмплер классификации изображений pytorch выдает странное значение длины? - PullRequest
0 голосов
/ 11 декабря 2019

Я тренируюсь Классификация изображений
Я сделал поезд и действительные наборы перекрестной проверки

EPOCHS = 5
SAVE_DIR = 'models'
MODEL_SAVE_PATH = os.path.join(SAVE_DIR, 'please.pt')
from torch.utils.data import DataLoader
best_valid_loss = float('inf')

if not os.path.isdir(f'{SAVE_DIR}'):
    os.makedirs(f'{SAVE_DIR}')
print("start")
for epoch in range(EPOCHS):
    print('================================',epoch ,'================================')
    for i , (train_idx, valid_idx) in enumerate(zip(train_indexes, valid_indexes)):
        print(i,train_idx,valid_idx,len(train_idx),len(valid_idx))

        trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler= SubsetRandomSampler(train_idx))
        valloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(valid_idx))

        print(len(trainloader.dataset),len(valloader.dataset))

        train_loss, train_acc ,model= train(model, device, trainloader, optimizer, criterion)
        valid_loss, valid_acc,model = evaluate(model, device, valloader, criterion)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model,MODEL_SAVE_PATH)

        print(f'| Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:05.2f}% | Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:05.2f}% |')

start =============================== 0 =============================== 0 [451 452 453 ... 21514 21515 21516] [0 1 2 ... 18402 18403 18404] 18827 269021517 21517

но вывод выборки очень странныйвыход

len (train_idx) и len (valid_idx) = 18827 2690

это правильноНо. при печати

(len (trainloader.dataset), len (valloader.dataset)) = 21517, 21517

что-то не так с сэмплером ???

...