Я пытаюсь реализовать алгоритм A3C в pytorch с некоторыми слоями свертки.Поэтому я запускаю несколько процессов в своей программе на Python, каждый из которых имеет локальную нейронную сеть.Сеть получает изображение, когда входной сигнал обрабатывается слоем свертки.Когда я запускаю эту программу на машине с Windows, у нее нет проблем и она работает нормально, но на машине с linux процесс каким-то образом блокируется, когда он пытается использовать сетевой метод forward.
Вот минимальный пример:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
state_shape = (3,10,10)
num_actions = 5
class Net(nn.Module):
def __init__(self, in_shape, num_actions):
super(Net, self).__init__()
self.in_shape = in_shape
self.num_actions = num_actions
self.features = nn.Sequential(
nn.Conv2d(in_shape[0], 16, kernel_size=3, stride=1),
nn.ReLU(),
nn.Conv2d(16, 16, kernel_size=3, stride=2),
nn.ReLU(),
)
self.fc = nn.Sequential(
nn.Linear(self.feature_size(), 256), # <-- problematic
nn.ReLU(),
)
def forward(self, x):
x = self.features(x)
def feature_size(self):
return self.features(torch.zeros(1, *self.in_shape)).view(1, -1).size(1)
class Worker(mp.Process):
def __init__(self):
super(Worker, self).__init__()
self.local_network = Net(state_shape, num_actions) # local network
def run(self):
s = np.zeros((3,10,10), dtype=np.float32)
for i in range(10):
print("{0}_before".format(i))
self.local_network.forward(torch.FloatTensor(s).unsqueeze(0))
print("{0}_after".format(i))
time.sleep(0.1)
if __name__ == "__main__":
# parallel training
workers = [Worker() for i in range(1)]
[w.start() for w in workers]
[w.join(timeout=10) for w in workers]
[w.terminate() for w in workers]
Я использую: torch 0.4.1 и torchvision 0.2.1
Кажется, что процесс инициализации каким-то образом неисправен.При инициализации линейного слоя с надписью выход свертки должен быть сплющен, чтобы знать размер входного сигнала линейного слоя.Поэтому
self.feature_size()
называется.Вычисление последовательности в процессе инициализации, кажется, вызывает проблему.К сожалению, я понятия не имею, почему.Кто-нибудь сталкивался с такой же проблемой?