pytorch batch_size не возвращает правильную партию? - PullRequest
0 голосов
/ 12 июля 2020

Независимо от того, что я поставил для batch_size, значение batch_size по умолчанию равно 1. Вот мой код

train_dataset = DataLoader(dataset=dataset,
                       batch_size=4,
                       shuffle=True,
                       num_workers=0)

, а набор данных - это настраиваемый набор данных следующим образом:

class ImageDataset(data.Dataset):

def __init__(self, root_dir, num_augments=2, transform=None):
    
    self.root_dir = root_dir
    self.img_names = os.listdir(root_dir)[::600]
    self.num_augments = num_augments
    self.transform = transform
    
def __getitem__(self, index):
    
    output = []
    img = Image.open(self.root_dir + '/' + self.img_names[index]).convert('RGB')
        
    for i in range(self.num_augments):
        if self.transform is not None:
            img_transform = self.transform(img)
            
        output.append(img_transform)
        
    output = torch.stack(output, axis=0)
        
    return output
        
def __len__(self):
    
    return len(self.img_names)

Я ожидаю каждая партия должна иметь размер [batch, num_augments, 3, height, width], но я получаю [1, num_augments, 3, height, width] независимо от размера моей партии.

...