Прежде всего, рекомендуется использовать вместо torch.nn.parallel.DistributedDataParallel
.
Вы можете проверить torch.nn.DataParallel
документацию, где описан процесс (вы также можете проверить исходный код и немного глубже покопаться в github , здесь как выполняется репликация модуля).
Вот примерно то, как это делается:
Инициализация
Все (или выбранные) идентификаторы устройств сохраняются в конструкторе и измерении, по которым будут разбросаны данные (почти всегда 0
означает быть разделено на устройства по пакету)
Вперед
Это делается во время каждого forward
прогона:
- Входы разбросаны (тензоры по размерам,
tuple
, list
, dict
неглубоко скопировано, другие данные распределяются между потоками). - Если есть только одно устройство, просто верните
module(*args, **kwargs)
- Если есть несколько устройств, скопируйте сеть с исходной машины на другие устройства (это делается каждый раз!)
- Вперед на каждом устройстве с соответствующим входом
- Собирать выходные данные с устройств на одно исходное устройство (конкатенация выходов) на исходный компьютер.
- Выполнить остальную часть кода, выполнить обратную передачу, обновить веса на Исходная машина et c.
Исходная машина - это cuda:0
по умолчанию, но ее можно выбрать. Также веса обновляются для одного device
, только пакет разбрасывается, а выходы собираются.