Несколько сетей PyTorch, работающих параллельно на разных процессорах - PullRequest
1 голос
/ 22 февраля 2020

Я пытаюсь, чтобы разные нейронные сети PyTorch работали параллельно на разных процессорах, но обнаружил, что это не приводит к ускорению по сравнению с последовательным их запуском.

Ниже приведен мой код это точно повторяет проблему. Если вы запустите этот код, он покажет, что при 2 процессах это занимает примерно вдвое больше времени, чем при его запуске с 1 процессом, но на самом деле это займет столько же времени.

import time
import torch.multiprocessing as mp
import gym
import numpy as np
import copy
import torch.nn as nn
import torch

class NN(nn.Module):
    def __init__(self, output_dim):
        nn.Module.__init__(self)
        self.fc1 = nn.Linear(4, 50)
        self.fc2 = nn.Linear(50, 500)
        self.fc3 = nn.Linear(500, 5000)
        self.fc4 = nn.Linear(5000, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.fc4(x)
        return x

def Worker(ix):
  print("Starting training for worker ", ix)
  env = gym.make('CartPole-v0')
  model = NN(2)
  for _ in range(2000):
    model(torch.Tensor(env.reset()))
  print("Finishing training for worker ", ix)

def overall_process(num_workers):
  workers = []
  for ix in range(num_workers):
    worker = mp.Process(target=Worker, args=(ix, ))
    workers.append(worker)
  [w.start() for w in workers]  
  for worker in workers:
    worker.join()

  print("Finished Training")  
  print(" ")

start = time.time()
overall_process(1)
print("Time taken: ", time.time() - start)
print(" ")

start = time.time()
overall_process(2)
print("Time taken: ", time.time() - start)

Кто-нибудь знает, почему это может быть происходит и как это исправить?

Я подумал, что это может быть потому, что сети PyTorch автоматически реализуют параллелизм ЦП в фоновом режиме, и поэтому я попытался добавить следующие 2 строки, но это не всегда решает проблему:

torch.set_num_threads(1)
torch.set_num_interop_threads(1)

1 Ответ

0 голосов
/ 26 февраля 2020

Ответ заключается в установке torch.set_num_threads (1) в начале каждого рабочего процесса (а не в основном процессе), как описано здесь: https://discuss.pytorch.org/t/multiple-networks-running-in-parallel-on-different-cpus/70482

...