У вас есть пара опций.
- Самый простой вариант, если наличие большого количества маленьких файлов не является проблемой, - это предварительно обработать каждый объект json в один файл.Затем вы можете просто прочитать каждый из них в зависимости от запрошенного индекса.Например,
class SingleFileDataset(Dataset):
def __init__(self, list_of_file_paths):
self.list_of_file_paths = list_of_file_paths
def __getitem__(self, index):
return np.load(self.list_of_file_paths[index]) # Or equivalent reading code for single file
Вы также можете разбить данные на постоянное количество файлов, а затем рассчитать, исходя из индекса, в каком файле находится образец. Затем вам нужно открыть этот файл в памяти и прочитать соответствующий индекс.Это дает компромисс между доступом к диску и использованием памяти.Предположим, у вас есть
n
семплов, и мы разбиваем семплы на
c
файлов равномерно во время предварительной обработки.Теперь, чтобы прочитать образец с индексом
i
, мы должны сделать
class SplitIntoFilesDataset(Dataset):
def __init__(self, list_of_file_paths, n_splits):
self.list_of_file_paths = list_of_file_paths
self.n_splits = n_splits
def __getitem__(self, index):
# index // n_splits is the relevant file, and
# index % len(self) is the index in in that file
file_to_load = self.list_of_file_paths[index // self.n_splits]
# Load file
file = np.load(file)
datapoint = file[index % len(self)]
Наконец, вы можете использовать файл HDF5 , который разрешает доступ к строкам на диске.Возможно, это лучшее решение, если у вас много данных, так как данные будут близко к диску.Здесь есть реализация здесь , которую я скопировал ниже:
import h5py
import torch
import torch.utils.data as data
class H5Dataset(data.Dataset):
def __init__(self, file_path):
super(H5Dataset, self).__init__()
h5_file = h5py.File(file_path)
self.data = h5_file.get('data')
self.target = h5_file.get('label')
def __getitem__(self, index):
return (torch.from_numpy(self.data[index,:,:,:]).float(),
torch.from_numpy(self.target[index,:,:,:]).float())
def __len__(self):
return self.data.shape[0]