Когда вы используете DistributedDataParallel
, у вас есть одна и та же модель на нескольких устройствах, которые синхронизируются, чтобы иметь точно такие же параметры.
При использовании DDP одна оптимизация - до сохранить модель только в одном процессе , а затем загрузить ее во все процессы, уменьшив накладные расходы на запись.
Поскольку они идентичны, нет необходимости сохранять модели из всех процессов, так как это просто запишет одни и те же параметры несколько раз. Например, если у вас есть 4 процесса / графического процессора, вы должны записать один и тот же файл 4 раза, а не один раз. Этого можно избежать, сохранив его только из основного процесса.
Это оптимизация для сохранения модели. Если вы загружаете модель сразу после того, как сохранили ее, будьте осторожны.
Если вы используете эту оптимизацию, убедитесь, что все процессы не начнут загружаться до завершения сохранения.
Если вы сохраните его только в одном процессе, этому процессу потребуется время для записи файла. Тем временем все другие процессы продолжаются, и они могут загрузить файл до того, как он был полностью записан на диск, что может привести к всевозможным неожиданным действиям или сбоям, независимо от того, не существует ли этот файл еще, вы пытаетесь прочитать неполный файл или вы загружаете старую версию модели (если вы перезаписываете тот же файл).
Кроме того, при загрузке модуля вам необходимо предоставить соответствующий аргумент map_location
, чтобы предотвратить переход процесса в чужие устройства. Если map_location
отсутствует, torch.load
сначала загрузит модуль в ЦП, а затем скопирует каждый параметр туда, где он был сохранен, что приведет к тому, что все процессы на одной машине будут использовать один и тот же набор устройств .
При сохранении параметров (или любого тензора в этом отношении) PyTorch включает устройство, на котором он был сохранен. Допустим, вы сохраняете ее из процесса, который использовал GPU 0 (device = "cuda:0"
), эта информация сохраняется, и когда вы ее загружаете, параметры автоматически переносятся на это устройство. Но если вы загрузите его в процессе, который использует GPU 1 (device = "cuda:1"
), вы неправильно загрузите их в "cuda:0"
. Теперь вместо использования нескольких графических процессоров у вас есть одна и та же модель несколько раз в одном графическом процессоре. Скорее всего, у вас закончится память, но даже если вы этого не сделаете, вы больше не будете использовать другие графические процессоры.
Чтобы избежать этой проблемы, вы должны установить соответствующее устройство для map_location
из torch.load
.
torch.load(PATH, map_location="cuda:1")
# Or load it on the CPU and later use .to(device) on the model
torch.load(PATH, map_location="cpu")