Я пытаюсь запустить приложение параллельной обработки с multiprocessing
, где я передаю класс, содержащий модель нейронной сети с использованием кера. Тем не менее, я получаю ошибку рассола при передаче объекта через starmap
метод модуля multiprocessing
. Пример игрушки приведен ниже, где агент запускает 10 эпизодов cartpole
параллельно:
from multiprocessing import Pool
import itertools
import gym
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
class Agent:
def __init__(self, input_dim, hidden_dims, output_dim):
self.input_dim = input_dim
self.hidden_dims = hidden_dims
self.output_dim = output_dim
inputs = Input(shape=(self.input_dim,))
net = inputs
# a layer instance is callable on a tensor, and returns a tensor
for h_dim in self.hidden_dims:
net = Dense(h_dim, activation='relu')(net)
net = Dense(self.output_dim, activation='softmax')(net)
# This creates a model that includes
# the Input layer and three Dense layers
self.model = Model(inputs=inputs, outputs=net)
def act(self, state):
state = np.reshape(state, [1, self.input_dim])
action = np.argmax(self.model.predict(state))
return action
env = gym.make("CartPole-v1")
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
def run_agent(num_episode, agent):
state = env.reset()
reward_episode = []
while True:
env.render()
action = agent.act(state)
state_next, reward, terminal, info = env.step(action)
reward = reward if not terminal else -reward
state_next = np.reshape(state_next, [1, observation_space])
reward_episode.append(reward)
state = state_next
if terminal:
break
return sum(reward_episode)
def run_parallel(agent):
episodes = list(range(10));
args_to_func = []
for i in episodes:
args_to_func.append([i, agent])
reward_agent = []
with Pool(processes=4) as pool:
reward_agent = pool.starmap(run_agent, args_to_func)
pool.close()
pool.join()
print(reward_agent)
if __name__ == "__main__":
agent = Agent(observation_space, [32], action_space)
run_parallel(agent)
Ошибка приводится ниже:
Traceback (most recent call last):
File "example_ev_alg.py", line 63, in <module>
run_parallel(agent)
File "example_ev_alg.py", line 56, in run_parallel
reward_agent = pool.starmap(run_agent, args_to_func)
File "/usr/lib/python3.6/multiprocessing/pool.py", line 274, in starmap
return self._map_async(func, iterable, starmapstar, chunksize).get()
File "/usr/lib/python3.6/multiprocessing/pool.py", line 644, in get
raise self._value
File "/usr/lib/python3.6/multiprocessing/pool.py", line 424, in _handle_tasks
put(task)
File "/usr/lib/python3.6/multiprocessing/connection.py", line 206, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/usr/lib/python3.6/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
TypeError: can't pickle _thread.RLock objects