Я вычисляю градиенты из частной сети и применяю их к другой основной сети. Затем я копирую веса для мастера в частную (это звучит избыточно, но терпите меня). Проблема в том, что с каждой итерацией get_weights становится медленнее, и мне даже не хватает памяти.
def work(self, session):
with session.as_default(), session.graph.as_default():
self.private_net = ACNetwork()
state = self.env.reset()
while counter<TOTAL_TR_STEPS:
action_index, action_vector = self.get_action(state)
next_state, reward, done, info = self.env.step(action_index)
....# store the new data : reward, state etc...
if done == True:
# end of episode
state = self.env.reset()
a_grads, c_grads = self.private_net.get_gradients()
self.master.update_from_gradients(a_grads, c_grads)
self._update_worker_net() #this is the slow one
!!!!!!
Это функция, которая использует get_weights.
def _update_worker_net(self):
self.private_net.actor_t.set_weights(\
self.master.actor_t.get_weights())
self.private_net.critic.set_weights(\
self.master.critic.get_weights())
return
Оглядываясь вокруг, я обнаружил пост, в котором предлагалось использовать
K.clear_session()
в конце времениблок (в сегменте !!!!!!), потому что каким-то образом новые узлы добавляются (?!) в граф. Но этот единственный вернул ошибку:
AssertionError: Do not use tf.reset_default_graph() to clear nested graphs. If you need a cleared graph, exit the nesting and create a new graph.
Есть ли более быстрый способ передачи весов? Есть ли способ не добавлять новые узлы (если это действительно так?)