Pytorch DataLoader завершается ошибкой, когда количество примеров точно не разделено на размер пакета - PullRequest
0 голосов
/ 13 июня 2019

Я написал собственный класс загрузчика данных в pytorch.Но он терпит неудачу при переборе всего количества партий внутри эпохи.Например, представьте, что у меня есть 100 примеров данных, и мой размер пакета равен 9. Он потерпит неудачу в 10-й итерации, сказав, что размер пакета отличается, что даст размер пакета 1 вместо 10. Я поместил свой пользовательский загрузчик данных ниже.Также я описал способ извлечения данных из загрузчика внутри цикла for.

class FlatDirectoryAudioDataset(tdata.Dataset): #customized dataloader

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.files = self.__setup_files()

    def __len__(self):
        """
        compute the length of the dataset
        :return: len => length of dataset
        """
        return len(self.files)

    def __setup_files(self):

        file_names = os.listdir(self.data_dir)
        files = []  # initialize to empty list

        for file_name in file_names:

            possible_file = os.path.join(self.data_dir, file_name)
            if os.path.isfile(possible_file) and (file_name.lower().endswith('.wav') or file_name.lower().endswith('.mp3')): #&& (possible_file.lower().endswith('.wav') or possible_file.lower().endswith('.mp3')):
                files.append(possible_file)

        # return the files list
        return files


    def __getitem__ (self,index):
        sample, _ = librosa.load(self.files[index], 16000)

        if self.transform:
            sample=self.transform(sample)

        sample = torch.from_numpy(sample)    
        return sample


from torch.utils.data import DataLoader 

    my_dataset=FlatDirectoryAudioDataset(source_directory,source_folder,source_label,transform = None,label=True)

dataloader_my = DataLoader(
        my_dataset,
        batch_size=batch_size,
        num_workers=0,
        shuffle=True)


for (i,batch) in enumerate(dataloader_my,0):  
       print(i)
       if batch.shape[0]!=16:
          print(batch.shape)
          assert batch.shape[0]==16,"Something wrong with the batch size"



Ответы [ 2 ]

2 голосов
/ 19 июня 2019

use drop_last = True utils.DataLoader (набор данных, batch_size = batch_size, shuffle = True, drop_last = True)

https://pytorch.org/docs/stable/data.html

2 голосов
/ 14 июня 2019

Короткий ответ

Установите drop_last=True для отбрасывания последней неполной партии

Длинный ответ

На основе вашего кода, создающего сокращенную версиюВ вашем загрузчике данных нет ошибок для размеров партии.

Используя 9 как batch_size и имея 100 элементов, в последней партии только один элемент.Выполнение приведенного ниже кода приводит к получению.

Установка drop_last = False Печатается последняя строка и печатается «исключение».

0 <class 'torch.Tensor'> torch.Size([9, 1])
1 <class 'torch.Tensor'> torch.Size([9, 1])
2 <class 'torch.Tensor'> torch.Size([9, 1])
3 <class 'torch.Tensor'> torch.Size([9, 1])
4 <class 'torch.Tensor'> torch.Size([9, 1])
5 <class 'torch.Tensor'> torch.Size([9, 1])
6 <class 'torch.Tensor'> torch.Size([9, 1])
7 <class 'torch.Tensor'> torch.Size([9, 1])
8 <class 'torch.Tensor'> torch.Size([9, 1])
9 <class 'torch.Tensor'> torch.Size([9, 1])
10 <class 'torch.Tensor'> torch.Size([9, 1])
# depends on drop_last=True|False
11 <class 'torch.Tensor'> torch.Size([1, 1])
Different batch size (last batch) torch.Size([1, 1])

Таким образом, партия производит достаточно хороших элементов партии, чтобы сделать еедо 100 всего

from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import torch.utils.data.dataset as tdata


class FlatDirectoryAudioDataset(tdata.Dataset):  # customized dataloader

    def __init__(self):
        self.files = self.__setup_files()

    def __len__(self):
        return len(self.files)

    def __setup_files(self):
        return np.array(range(100))

    def __getitem__(self, index):
        file = self.files[index]
        sample = np.array([file])
        sample = torch.from_numpy(sample)
        return sample


data = FlatDirectoryAudioDataset()

my_dataset = FlatDirectoryAudioDataset()

batch_size = 9

dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    drop_last=True)

for i, sample in enumerate(dataloader_my, 0):
    print(i, print(type(sample), sample.shape)
    if sample.shape[0] != batch_size:
        print("Different batch size (last batch)", sample.shape)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...