Короткий ответ
Установите drop_last=True
для отбрасывания последней неполной партии
Длинный ответ
На основе вашего кода, создающего сокращенную версиюВ вашем загрузчике данных нет ошибок для размеров партии.
Используя 9 как batch_size
и имея 100 элементов, в последней партии только один элемент.Выполнение приведенного ниже кода приводит к получению.
Установка drop_last = False Печатается последняя строка и печатается «исключение».
0 <class 'torch.Tensor'> torch.Size([9, 1])
1 <class 'torch.Tensor'> torch.Size([9, 1])
2 <class 'torch.Tensor'> torch.Size([9, 1])
3 <class 'torch.Tensor'> torch.Size([9, 1])
4 <class 'torch.Tensor'> torch.Size([9, 1])
5 <class 'torch.Tensor'> torch.Size([9, 1])
6 <class 'torch.Tensor'> torch.Size([9, 1])
7 <class 'torch.Tensor'> torch.Size([9, 1])
8 <class 'torch.Tensor'> torch.Size([9, 1])
9 <class 'torch.Tensor'> torch.Size([9, 1])
10 <class 'torch.Tensor'> torch.Size([9, 1])
# depends on drop_last=True|False
11 <class 'torch.Tensor'> torch.Size([1, 1])
Different batch size (last batch) torch.Size([1, 1])
Таким образом, партия производит достаточно хороших элементов партии, чтобы сделать еедо 100 всего
from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import torch.utils.data.dataset as tdata
class FlatDirectoryAudioDataset(tdata.Dataset): # customized dataloader
def __init__(self):
self.files = self.__setup_files()
def __len__(self):
return len(self.files)
def __setup_files(self):
return np.array(range(100))
def __getitem__(self, index):
file = self.files[index]
sample = np.array([file])
sample = torch.from_numpy(sample)
return sample
data = FlatDirectoryAudioDataset()
my_dataset = FlatDirectoryAudioDataset()
batch_size = 9
dataloader_my = DataLoader(
my_dataset,
batch_size=batch_size,
num_workers=0,
shuffle=True,
drop_last=True)
for i, sample in enumerate(dataloader_my, 0):
print(i, print(type(sample), sample.shape)
if sample.shape[0] != batch_size:
print("Different batch size (last batch)", sample.shape)