Загрузчик данных по умолчанию в Pytorch застревает в тренировочном наборе классификации больших изображений - PullRequest
1 голос
/ 11 февраля 2020

Я обучаю модели классификации изображений в Pytorch и использую их загрузчик данных по умолчанию для загрузки моих данных обучения. У меня очень большой обучающий набор данных, поэтому обычно пара тысяч образцов изображений на класс. Я тренировал модели с общим количеством изображений около 200 тыс. В прошлом. Однако я обнаружил, что, когда в общей сложности более миллиона изображений, загрузчик данных Pytorch застревает.

Я считаю, что код зависает, когда я вызываю datasets.ImageFolder(...). Когда я Ctrl- C, это последовательно вывод:

Traceback (most recent call last):                                                                                                 │
  File "main.py", line 412, in <module>                                                                                            │
    main()                                                                                                                         │
  File "main.py", line 122, in main                                                                                                │
    run_training(args.group, args.num_classes)                                                                                     │
  File "main.py", line 203, in run_training                                                                                        │
    train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True)                                                      │
  File "main.py", line 236, in create_dataloader                                                                                   │
    dataset = datasets.ImageFolder(directory, trans)                                                                               │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__     │
    is_valid_file=is_valid_file)                                                                                                   │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__      │
    samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)                                                     │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset  │
    for root, _, fnames in sorted(os.walk(d)):                                                                                     │
  File "/usr/lib/python3.5/os.py", line 380, in walk                                                                               │
    is_dir = entry.is_dir()                                                                                                        │
Keyboard Interrupt                                                                                                                       

Я думал, что где-то может быть тупик, однако на основе стекового вывода из Ctrl- C это не похоже его ждет на замке. Тогда я подумал, что загрузчик данных был медленным, потому что я пытался загрузить намного больше данных. Я позволил ему работать в течение приблизительно 2 дней, и это не делало никакого прогресса, и за прошлые 2 часа загрузки я проверил, объем использования ОЗУ остался прежним. Я также смог загрузить учебные наборы данных с более чем 200 000 изображений менее чем за пару часов в прошлом. Я также попытался обновить свою машину GCP, чтобы иметь 32 ядра, 4 графических процессора и более 100 ГБ оперативной памяти, однако, похоже, что после загрузки определенного объема памяти загрузчик данных просто зависает.

Я не совсем понимаю, как может зависать загрузчик данных при циклическом перемещении по каталогу, и я до сих пор не уверен, что он завис или просто слишком медленный. Есть ли способ изменить загрузчик данных Pytortch, чтобы можно было обрабатывать более 1 миллиона изображений для обучения? Любые предложения по отладке также приветствуются!

Спасибо!

1 Ответ

1 голос
/ 11 февраля 2020

Это не проблема с DataLoader, это проблема с torchvision.datasets.ImageFolder и с тем, как он работает (и почему он работает намного хуже, чем больше у вас данных).

Он висит на этой строке, как указывает ваша ошибка:

for root, _, fnames in sorted(os.walk(d)): 

Источник можно найти здесь .

Основная проблема заключается в том, что он сохраняет каждый path и соответствующий label в гиганте list, см. Код ниже (несколько вещей для краткости исключены):

def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    # Iterate over all subfolders which were found previously
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target) # Create path to this subfolder
        # Assuming it is directory (which usually is the case)
        for root, _, fnames in sorted(os.walk(d, followlinks=True)):
            # Iterate over ALL files in this subdirectory
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                # Assuming it is correctly recognized as image file
                item = (path, class_to_idx[target])
                # Add to path with all images
                images.append(item)

    return images

Очевидно, что изображения будут содержать 1 миллион строк (также довольно длинных) и соответствующие int для классов, которые определенно являются много и зависит от ОЗУ и ЦП.

Вы можете создавать свои собственные наборы данных (при условии, что вы заранее изменяете имена своих изображений), поэтому память не будет занята dataset.

Настройка структуры данных

Структура вашей папки должна выглядеть следующим образом:

root
    class1
    class2
    class3
    ...

Использовать, сколько классов вам нужно / нужно.

Теперь каждый class должны иметь следующие данные:

class1
    0.png
    1.png
    2.png
    ...

Учитывая, что вы можно перейти к созданию наборов данных.

Создать наборы данных

Ниже torch.utils.data.Dataset использует PIL для открытия изображений, вы можете сделать это по-другому, хотя:

import os
import pathlib

import torch
from PIL import Image


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, root: str, folder: str, klass: int, extension: str = "png"):
        self._data = pathlib.Path(root) / folder
        self.klass = klass
        self.extension = extension
        # Only calculate once how many files are in this folder
        # Could be passed as argument if you precalculate it somehow
        # e.g. ls | wc -l on Linux
        self._length = sum(1 for entry in os.listdir(self._data))

    def __len__(self):
        # No need to recalculate this value every time
        return self._length

    def __getitem__(self, index):
        # images always follow [0, n-1], so you access them directly
        return Image.open(self._data / "{}.{}".format(str(index), self.extension))

Теперь вы можете легко создавать свои наборы данных (структура папок, как указано выше:

root = "/path/to/root/with/images"
dataset = (
    ImageDataset(root, "class0", 0)
    + ImageDataset(root, "class1", 1)
    + ImageDataset(root, "class2", 2)
)

Вы можете добавить столько datasets с указанными классами, сколько вы будете sh, сделайте это в l oop или что-то еще.

Наконец, используйте torch.utils.data.DataLoader как обычно, например:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
...