Я пытаюсь закодировать Asynchronous Actor Criti c в PyTorch на основе этого репо: https://github.com/seungeunrho/minimalRL/blob/master/a3c.py, но я меняю класс ActorCriti c, чтобы использовать тот, который я закодировал сам .
В основном у меня есть класс A3 C, его экземпляр, global_model, с разделяемой памятью, и я использую torch.multiprocessing, чтобы открыть некоторые процессы для параллельного обучения модели. В каждом процессе в начале я должен создать новый экземпляр модели с именем local_model, чтобы продолжить обучение, но процесс застревает в инициализации локальной модели, даже если одна из глобальных моделей работает каждый время.
Пытаясь отладить его, я вижу, что он входит в функцию A3 C. init и SharedActorCriti c. init , но на этом он останавливается сразу после того, как поставил печать КПП. Однако, если я печатаю любое выражение, содержащее список (crit_param_gen), волшебным образом все работает. Я также заметил, что печать только «crit_param_gen» не годится.
Есть идеи, почему это происходит?
То же самое происходит, если я использую local_model = copy.deepcopy (global_model) как функцию create_local_model, т.е. работает только в том случае, если этот print присутствует.
В псевдокоде:
import torch.multiprocessiA3Cng as mp
import torch.nn as nn
import itertools as it
debug = True
A3C(nn.Module):
def __init__(self, model, n_features):
...
self.AC_architecture = SharedActorCritic(model, n_features)
class SharedActorCritic(nn.Module):
def __init__(self, model, n_features):
super(SharedActorCritic, self).__init__()
self.shared_architecture = model(n_features) # inherits from nn.Module
self.actor = SharedActor(n_features) # inherits from nn.Module
self.critic = SharedCritic(n_features) # inherits from nn.Module
self.critic_target = BaseCritic(model, n_features) # inherits from nn.Module
critic_param_gen = it.chain(self.shared_architecture.parameters(), self.critic.parameters())
print("checkpoint")
if debug: print(list(critic_param_gen)) # this makes the whole thing work
for trg_params, params in zip(self.critic_target.parameters(), critic_param_gen ):
trg_params.data.copy_(params.data)
def create_local_model(model, n_features):
local_model = A3C(model, n_features)
print("Process ended")
# in the main
global_model = Model() # works
global_model.share_memory() # doesn't really matter
p = mp.Process(target=create_local_model, args=(model, n_features, ))
p.start()
print("Process started")
p.join()
----
# output if debug is True
Process started
checkpoint
[ ...actual list of critic_param_gen ... ]
Process ended
# output if debug is False
Process started
checkpoint
# and then runs forever
Edit: разгадал загадку с оператором печати благодаря snakecharmerb. Я создал минимальный воспроизводимый пример. Кажется, что если сеть достаточно велика, операция копирования прерывается, если выполняется в процессе, но не за его пределами (поскольку глобальная модель может быть создана).
import torch.nn as nn
import torch.multiprocessing as mp
import copy
class Net(nn.Module):
def __init__(self, n_features=256, n_layers=8):
super(Net, self).__init__()
self.net1 = nn.Sequential(*nn.ModuleList([nn.Linear(n_features, n_features) for _ in range(n_layers)]))
self.net2 = nn.Sequential(*nn.ModuleList([nn.Linear(n_features, n_features) for _ in range(n_layers)]))
for p1, p2 in zip(self.net1.parameters(), self.net2.parameters()):
p1.data.copy_(p2.data)
def forward(self, x):
return self.net(x)
def create_local_model_v1(global_model):
local_model = copy.deepcopy(global_model)
print("Process ended")
%%time
global_model = Net(16,2)
print("Global model created")
p = mp.Process(target=create_local_model_v1, args=(global_model,))
p.start()
print("Process started")
p.join()
# Output
Global model created
Process ended
Process started
CPU times: user 3 ms, sys: 11.9 ms, total: 14.9 ms
Wall time: 45.1 ms
%%time
global_model = Net(256,8)
print("Global model created")
p = mp.Process(target=create_local_model_v1, args=(global_model,))
p.start()
print("Process started")
p.join()
# Output - Gets stuck
Global model created
Process started