get_weights медленно работает с каждой итерацией - PullRequest
0 голосов
/ 26 октября 2019

Я вычисляю градиенты из частной сети и применяю их к другой основной сети. Затем я копирую веса для мастера в частную (это звучит избыточно, но терпите меня). Проблема в том, что с каждой итерацией get_weights становится медленнее, и мне даже не хватает памяти.

    def work(self, session):
        with session.as_default(), session.graph.as_default(): 
            self.private_net = ACNetwork()

            state = self.env.reset()

            while counter<TOTAL_TR_STEPS:

                action_index, action_vector = self.get_action(state)
                next_state, reward, done, info = self.env.step(action_index)
                ....# store the new data : reward, state etc...
                if done == True:
                    # end of episode
                    state = self.env.reset()
                    a_grads, c_grads = self.private_net.get_gradients()
                    self.master.update_from_gradients(a_grads, c_grads)
                    self._update_worker_net()  #this is the slow one
                !!!!!!

Это функция, которая использует get_weights.

def _update_worker_net(self):
      self.private_net.actor_t.set_weights(\
                               self.master.actor_t.get_weights())
      self.private_net.critic.set_weights(\
                               self.master.critic.get_weights())
return

Оглядываясь вокруг, я обнаружил пост, в котором предлагалось использовать

  K.clear_session()

в конце времениблок (в сегменте !!!!!!), потому что каким-то образом новые узлы добавляются (?!) в граф. Но этот единственный вернул ошибку:

AssertionError: Do not use tf.reset_default_graph() to clear nested graphs. If you need a cleared graph, exit the nesting and create a new graph.

Есть ли более быстрый способ передачи весов? Есть ли способ не добавлять новые узлы (если это действительно так?)

1 Ответ

2 голосов
/ 26 октября 2019

Это обычно происходит, когда вы динамически добавляете новые узлы на график. Пример ситуации:

while True:
    grad_op = optimizer.get_gradients()
    session.run([gradients])

Где get_gradients добавит новые операции в график. Операции, возвращаемые get_gradients, не будут меняться независимо от того, сколько раз вы вызываете его, поэтому одного вызова должно быть достаточно. Правильный способ переписать это будет:

grad_op = optimizer.get_gradients()
while True:
    session.run([gradients])

Возможно, что-то подобное происходит в вашем коде. Постарайтесь убедиться, что вы не создаете новые операции в цикле while.

...