Параметр копирования PyTorch застревает в многопроцессорной обработке, если параметры слишком велики - PullRequest
2 голосов
/ 07 мая 2020

Я пытаюсь закодировать Asynchronous Actor Criti c в PyTorch на основе этого репо: https://github.com/seungeunrho/minimalRL/blob/master/a3c.py, но я меняю класс ActorCriti c, чтобы использовать тот, который я закодировал сам .

В основном у меня есть класс A3 C, его экземпляр, global_model, с разделяемой памятью, и я использую torch.multiprocessing, чтобы открыть некоторые процессы для параллельного обучения модели. В каждом процессе в начале я должен создать новый экземпляр модели с именем local_model, чтобы продолжить обучение, но процесс застревает в инициализации локальной модели, даже если одна из глобальных моделей работает каждый время.

Пытаясь отладить его, я вижу, что он входит в функцию A3 C. init и SharedActorCriti c. init , но на этом он останавливается сразу после того, как поставил печать КПП. Однако, если я печатаю любое выражение, содержащее список (crit_param_gen), волшебным образом все работает. Я также заметил, что печать только «crit_param_gen» не годится.

Есть идеи, почему это происходит?

То же самое происходит, если я использую local_model = copy.deepcopy (global_model) как функцию create_local_model, т.е. работает только в том случае, если этот print присутствует.

В псевдокоде:

import torch.multiprocessiA3Cng as mp
import torch.nn as nn
import itertools as it

debug = True

A3C(nn.Module):
  def __init__(self, model, n_features):
     ... 
     self.AC_architecture = SharedActorCritic(model, n_features)

class SharedActorCritic(nn.Module):
    def __init__(self, model, n_features):
        super(SharedActorCritic, self).__init__()

        self.shared_architecture = model(n_features) # inherits from nn.Module
        self.actor = SharedActor(n_features) # inherits from nn.Module
        self.critic = SharedCritic(n_features) # inherits from nn.Module

        self.critic_target = BaseCritic(model, n_features) # inherits from nn.Module

        critic_param_gen = it.chain(self.shared_architecture.parameters(), self.critic.parameters())
        print("checkpoint")
        if debug: print(list(critic_param_gen)) # this makes the whole thing work
        for trg_params, params in zip(self.critic_target.parameters(), critic_param_gen ):
            trg_params.data.copy_(params.data)

def create_local_model(model, n_features):
    local_model = A3C(model, n_features)
    print("Process ended")

# in the main
global_model = Model() # works
global_model.share_memory() # doesn't really matter

p = mp.Process(target=create_local_model, args=(model, n_features, ))
p.start()
print("Process started")
p.join()

----
# output if debug is True
Process started
checkpoint
[ ...actual list of critic_param_gen ... ]
Process ended

# output if debug is False
Process started
checkpoint
# and then runs forever

Edit: разгадал загадку с оператором печати благодаря snakecharmerb. Я создал минимальный воспроизводимый пример. Кажется, что если сеть достаточно велика, операция копирования прерывается, если выполняется в процессе, но не за его пределами (поскольку глобальная модель может быть создана).

import torch.nn as nn
import torch.multiprocessing as mp
import copy

class Net(nn.Module):
    def __init__(self, n_features=256, n_layers=8):
        super(Net, self).__init__()
        self.net1 = nn.Sequential(*nn.ModuleList([nn.Linear(n_features, n_features) for _ in range(n_layers)]))
        self.net2 = nn.Sequential(*nn.ModuleList([nn.Linear(n_features, n_features) for _ in range(n_layers)]))

        for p1, p2 in zip(self.net1.parameters(), self.net2.parameters()):
            p1.data.copy_(p2.data)

    def forward(self, x):
        return self.net(x)

def create_local_model_v1(global_model):
    local_model = copy.deepcopy(global_model)
    print("Process ended")

%%time
global_model = Net(16,2)
print("Global model created")
p = mp.Process(target=create_local_model_v1, args=(global_model,))
p.start()
print("Process started")
p.join()

# Output
Global model created
Process ended
Process started
CPU times: user 3 ms, sys: 11.9 ms, total: 14.9 ms
Wall time: 45.1 ms

%%time
global_model = Net(256,8)
print("Global model created")
p = mp.Process(target=create_local_model_v1, args=(global_model,))
p.start()
print("Process started")
p.join()

# Output - Gets stuck
Global model created
Process started


1 Ответ

0 голосов
/ 30 июля 2020

TL; DR: используйте torch.multiprocessing.spawn

Я недостаточно опытен, чтобы определить точную причину и решение этой ошибки, но проблема возникает на данный момент в torch/nn/parameter.py:

result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)

Это вызывается во время процесса глубокого копирования. Чтобы исследовать немного больше, я провел несколько более подробный эксперимент, чтобы проверить, какие параметры и среды вызывают зависание. Суть результатов в том, что проблема не в размере модели, а в том, сколько функций / проблем может вызвать проблемы. У меня 256 функций вызывают зависание вне зависимости от того, сколько слоев. Другая более любопытная проблема заключается в том, что когда я удаляю часть инициализации, где параметры из net1 копируются в net2, зависание исчезает, однако если я ничего не отправляю другому процессу, все работает нормально. Наконец, при использовании функции spawn все работает нормально, пока количество слоев не превысит 256.

Мне нужно предупредить все о зависании, насколько я могу судить, это тупик, но он может быть просто какой-то чрезвычайно медленный процесс. Это очень маловероятно, потому что кажется, что вся активность прекращается, однако я не мог подтвердить, что это тупик, потому что, когда я пошел на трассировку кода C во время зависания, все, что я получил, был адрес памяти (на самом деле Подтвердите все, что я думаю, мне нужно восстановить torch с некоторыми параметрами отладки ...). В любом случае, я примерно на 99% уверен, что это тупик, вероятно, из-за чего-то в многопроцессорной обработке. Моя уверенность настолько высока, что код даже не реагирует на сигналы. Если бы все работало так, как ожидалось, я бы ожидал, что программа хотя бы позволит мне распечатать трассировку от обработчика сигнала, но ничего.

Я нашел следующее сообщение в блоге несколько приятным: Tragi c история тупиковой Python очереди

Помимо этого, мое мнение на данный момент - это чертовски комбинация torch и multiprocessing.

Если кому-то интересно дайте мне знать, чтобы увидеть код экспериментов, которые я проводил, или их результат.

...