Как работать с большим набором данных в pytorch - PullRequest
0 голосов
/ 18 февраля 2019

У меня огромный набор данных, который не помещается в памяти (150G), и я ищу лучший способ работы с ним в Pytorch.Набор данных состоит из нескольких .npz файлов по 10 тыс. Выборок в каждом.Я пытался создать класс Dataset

class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.files = os.listdir(self.path)
        self.file_length = {}
        for f in self.files:
            # Load file in as a nmap
            d = np.load(os.path.join(self.path, f), mmap_mode='r')
            self.file_length[f] = len(d['y'])

    def __len__(self):
        raise NotImplementedException()

    def __getitem__(self, idx):                
        # Find the file where idx belongs to
        count = 0
        f_key = ''
        local_idx = 0
        for k in self.file_length:
            if count < idx < count + self.file_length[k]:
                f_key = k
                local_idx = idx - count
                break
            else:
                count += self.file_length[k]
        # Open file as numpy.memmap
        d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
        # Actually fetch the data
        X = np.expand_dims(d['X'][local_idx], axis=1)
        y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
        return X, y

, но когда выборка действительно получена, это занимает более 30 секунд.Похоже, что все .npz открыто, хранится в оперативной памяти и имеет доступ к нужному индексу.Как быть более эффективным?

РЕДАКТИРОВАТЬ

Похоже, неправильно .npz файлы см. Сообщение , но есть ли лучшеподход?

ПРЕДЛОЖЕНИЕ РЕШЕНИЯ

Как предложено @covariantmonkey, lmdb может быть хорошим выбором.На данный момент, поскольку проблема связана с .npz файлами, а не с memmap, я перемоделировал свой набор данных, разделив файлы .npz пакетов на несколько .npy файлов.Теперь я могу использовать ту же логику, где memmap имеет смысл и действительно быстр (несколько мс для загрузки семпла).

1 Ответ

0 голосов
/ 19 февраля 2019

Насколько велики отдельные .npz файлы?Я был в таком же положении месяц назад.Различные форум сообщения, поиск в Google позже я пошел по маршруту lmdb .Вот что я сделал

  1. Разбейте большой набор данных на достаточно маленькие файлы, которые я могу поместить в gpu - каждый из них по сути является моей мини-партией.На этой стадии я не оптимизировал время загрузки , только память.
  2. создаю индекс lmdb с key = filename и data = np.savez_compressed(stff)

lmdb занимаетзабота о mmap для вас и безумно быстрая загрузка.

С уважением,
A

PS: savez_compessed требуется объект байта, чтобы вы могли сделать что-то вроде

output = io.BytesIO()
np.savez_compressed(output, x=your_np_data)
#cache output in lmdb
...