Ошибка Pickle при передаче объектов класса с моделями keras через многопроцессорный модуль - PullRequest
0 голосов
/ 16 марта 2020

Я пытаюсь запустить приложение параллельной обработки с multiprocessing, где я передаю класс, содержащий модель нейронной сети с использованием кера. Тем не менее, я получаю ошибку рассола при передаче объекта через starmap метод модуля multiprocessing. Пример игрушки приведен ниже, где агент запускает 10 эпизодов cartpole параллельно:

from multiprocessing import Pool
import itertools
import gym
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

class Agent:
    def __init__(self, input_dim, hidden_dims, output_dim):
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        inputs = Input(shape=(self.input_dim,))

        net = inputs
        # a layer instance is callable on a tensor, and returns a tensor
        for h_dim in self.hidden_dims:
            net = Dense(h_dim, activation='relu')(net)

        net = Dense(self.output_dim, activation='softmax')(net)
        # This creates a model that includes
        # the Input layer and three Dense layers
        self.model = Model(inputs=inputs, outputs=net)

    def act(self, state):
        state = np.reshape(state, [1, self.input_dim])
        action = np.argmax(self.model.predict(state))
        return action

env = gym.make("CartPole-v1")
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n

def run_agent(num_episode, agent):
    state = env.reset()
    reward_episode = []
    while True:
        env.render()
        action = agent.act(state)
        state_next, reward, terminal, info = env.step(action)
        reward = reward if not terminal else -reward
        state_next = np.reshape(state_next, [1, observation_space])
        reward_episode.append(reward)
        state = state_next
        if terminal:
            break
    return sum(reward_episode)

def run_parallel(agent):
    episodes = list(range(10));
    args_to_func = []
    for i in episodes:
        args_to_func.append([i, agent])

    reward_agent = []
    with Pool(processes=4) as pool:
        reward_agent = pool.starmap(run_agent, args_to_func)
        pool.close()
        pool.join()
    print(reward_agent)

if __name__ == "__main__":
    agent = Agent(observation_space, [32], action_space)
    run_parallel(agent)

Ошибка приводится ниже:

Traceback (most recent call last):
  File "example_ev_alg.py", line 63, in <module>
    run_parallel(agent)
  File "example_ev_alg.py", line 56, in run_parallel
    reward_agent = pool.starmap(run_agent, args_to_func)
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 274, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 644, in get
    raise self._value
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 424, in _handle_tasks
    put(task)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/usr/lib/python3.6/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
TypeError: can't pickle _thread.RLock objects
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...