Как читать из набора данных с высоким IO в pytorch, который растет от эпохи к эпохе - PullRequest
8 голосов
/ 07 февраля 2020

Я использую Tensorflow, но я пишу документацию для пользователей, которая обычно варьируется в зависимости от среды глубокого обучения .

При работе с наборами данных, которые не помещаются в локальную файловую систему ( TB +) Я собираю данные из удаленного хранилища данных и записываю их локально в стандартный формат Tensorflow tfrecords.

В первую эпоху обучения у меня будет только выборка нескольких значений, поэтому эпоха локальных данных очень мала, я тренируюсь на ней. В эпоху 2 я пересматриваю, какие файлы данных были созданы моими подпроцессами выборки (теперь больше), и обучаюсь расширенному набору локальных файлов данных для следующей эпохи. Повторите процесс каждую эпоху. Таким образом, я создаю локальный кеш сэмплов и могу вытеснять старые сэмплы, когда заполняю локальное хранилище. Кэш локальных сэмплов увеличивается примерно в то время, когда модель больше всего нуждается в дисперсии (ближе ко второй части обучения).

В Python / Tensorflow крайне важно, чтобы я не десериализовал данные в обучении Python. oop процесс, потому что Python GIL не может поддерживать скорости передачи данных (300-600 МБ / с c, данные являются необработанными c несжимаемыми), и, следовательно, производительность графического процессора снижается, когда Python GIL не может обслуживать обучение l oop быстро.

Запись сэмплов в файл tfrecords из подпроцессов (python многопроцессорная обработка) позволяет нативному тензорному потоку TFRecordsDataset выполнять десериализацию за пределами Python и, таким образом, мы обходим проблемы Python GIL, и я могу насытить GPU с высокой скоростью передачи данных ввода-вывода.

Я хотел бы знать, как бы я решил эту проблему в Pytorch. Я пишу об используемой стратегии выборки и хочу предоставить конкретные c рекомендации пользователям как Tensorflow, так и PyTorch, но я недостаточно хорошо знаю экосистему предварительной обработки PyTorch, чтобы писать достаточно подробно.

Примечание: единственное решение на основе Python для поддержки этих скоростей передачи данных может прийти в Python 3,8 с общей памятью System V и многопроцессорностью, но я не пробовал это все же, поскольку поддержка этого не совсем достаточна (скоро это будет). Существующих многопроцессорных решений недостаточно, поскольку они требуют десериализации в процессе обучения l oop и, таким образом, блокируют GIL во время десериализации при высоких скоростях ввода-вывода.

1 Ответ

7 голосов
/ 17 февраля 2020

На самом деле, вы можете легко десериализовать данные в подпроцессе, используя torch.utils.data.DataLoader. Установив для аргумента num_workers значение 1 или большее значение, вы можете порождать подпроцессы с их собственными python интерпретаторами и GIL.

loader = torch.utils.data.DataLoader(your_dataset, num_workers=n, **kwargs)
for epoch in range(epochs):
    for batch_idx, data in enumerate(loader):
         # loader in the main process does not claim GIL at this point

A Dataloader требует torch.utils.data.Dataset для получения данных. Это может быть не тривиальная задача для реализации надлежащего подкласса в вашем случае. Если вам нужно воссоздать экземпляр Dataset для каждой эпохи, вы можете сделать что-то вроде этого.

for epcoh in range(epochs):
    dset = get_new_dataset()
    loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
    for batch_idx, data in enumerate(loader):
        # Do training

или даже лучше

dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)

for epcoh in range(epochs):
    last_batch_idx =  (len(dset)-1) // loader.batch_size
    for batch_idx, data in enumerate(loader):
        # Prepare next loader in advance to avoid blocking
        if batch_idx == last_batch_idx:
            dset = get_new_dataset()
            loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
        # Do training

В качестве примечания, обратите внимание что на работу с привязкой к ЦП влияет GIL в большинстве случаев, а не на операции ввода-вывода, т. е. threading подойдет для любой чисто тяжелой операции ввода-вывода, и вам даже не нужен subprocess. Для получения дополнительной информации, пожалуйста, обратитесь к этому вопросу и этой википедии статье .

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...