Я попытался реализовать процесс самостоятельного генерирования данных для настольных игр, используя несколько процессоров для одновременной самостоятельной игры. Для родительского процесса я создал модель 4 NN для 30 процессоров (1 модель для 10 процессоров и 1 модель для обучения), каждая модель находится в разных графических процессорах (модель реализована в виде архитектуры, похожей на 20 блоков с сетью), псевдо-код выглядит следующим образом
nnet = NN(gpu_num=0)
nnet1 = NN(gpu_num=1)
nnet2 = NN(gpu_num=2)
nnet3 = NN(gpu_num=3)
for i in range(num_iteration):
nnet1.load_state_dict(nnet.state_dict())
nnet2.load_state_dict(nnet.state_dict())
nnet3.load_state_dict(nnet.state_dict())
samples = parallel_self_play()
nnet.train(samples)
parallel_self_play () реализовано следующим образом
pool = mp.Pool(processes=num_cpu) #30
for i in range(self.args.numEps):
results = []
if i % 3 == 0:
net = self.nnet1
elif i % 3 == 1:
net = self.nnet2
else:
net = self.nnet3
results.append(pool.apply_async(AsyncSelfPlay, args=(net))
# get results from results array then return it
return results
Мой код работает отлично с почти 100% использованием gpu в течение первой самостоятельной игры (менее 10 минут на итерацию), но после первой итерации (обучения), когда я загружал новые веса в nnet1-3, использование gpu никогда не достигает 80% снова (~ 30 минут - 1 час на итерацию). Я замечаю несколько вещей, пока бездельничаю со мной код
Эта модель включает в себя слои батчнорм, когда переключение модели в режим train () -> train -> switch to eval () приводит к тому, что самовоспроизведение (использование прямого прохода от модели) вообще не использует gpu.
Если он не переключается из eval () -> train () (тренировка в режиме eval), это приводит к снижению использования графического процессора (30-50%), но не полностью.
Если модели, которые не являются основными, не загружают веса от основной, при самостоятельном воспроизведении по-прежнему используется 100% графического процессора, поэтому я предполагаю, что что-то произошло в процессе обучения и изменило некоторые состояния в модель.
Это также происходит, когда используется только 8 процессоров - архитектура 1gpu и модель поезда на лету (без промежуточного).
Может ли кто-нибудь подсказать мне, как исправить мой код или как мне обучить мою модель?