Параллелизм при обучении нейронных сетей может быть достигнут двумя способами.
- Параллелизм данных - разделите большой пакет на две части и выполните один и тот же набор операций, но по отдельности на двух разных графических процессорах соответственно
- Параллелизм модели - разделить вычисления и запустить их на разных графических процессорах
Как вы уже задали в этом вопросе, вы хотели бы разбить расчет, который относится ко второй категории. Не существует готовых способов достижения параллелизма модели. PyTorch предоставляет примитивы для параллельной обработки с использованием пакета torch.distributed
. В этом учебном пособии подробно рассматриваются детали пакета, и вы можете выработать подход для достижения необходимого параллелизма модели.
Однако, параллелизм модели может быть очень сложным для достижения. Общий способ заключается в параллелизме данных с torch.nn.DataParallel
или torch.nn.DistributedDataParallel
. В обоих методах вы будете запускать одну и ту же модель на двух разных графических процессорах, однако один огромный пакет будет разделен на две меньшие порции. Градиенты будут накапливаться на одном графическом процессоре, и происходит оптимизация. Оптимизация происходит на одном графическом процессоре Dataparallel
и параллельно на всех графических процессорах DistributedDataParallel
с помощью многопроцессорной обработки.
В вашем случае, если вы используете DataParallel
, вычисления все равно будут выполняться на двух разных графических процессорах. Если вы заметили дисбаланс в использовании GPU, это может быть связано с тем, как был разработан DataParallel
. Вы можете попробовать использовать DistributedDataParallel
, который является самым быстрым способом тренировки на нескольких графических процессорах в соответствии с docs .
Есть и другие способы обработки очень больших партий. Эта статья подробно описывает их, и я уверен, что это будет полезно. Несколько важных моментов:
- делать накопление градиента для больших партий
- Использовать DataParallel
- Если этого недостаточно, используйте DistributedDataParallel