Tensorflow рассчитать потери вне графика - PullRequest
0 голосов
/ 04 мая 2019

Можно ли рассчитать потери за пределами графика тензопотока? Я пишу алгоритм критики актера, и когда я хочу применить градиенты, мне приходится обращаться к временному проходу.

На каждом шаге я вызываю прямой проход, чтобы получить свои логиты и значения:

# collect data
for _ in range(NUM_STEPS):
   sample = self.model.get_sample(s)
   pong_actions = self.get_action(sample)
   s_, r, d, i = self.envs.step(pong_actions)
   states.append(s)
   actions.append(np.expand_dims(sample, axis=1))
   rewards.append(np.expand_dims(np.float32(r), axis=1))
   masks.append(np.expand_dims(np.float32(1 - d), axis=1))

   # do some stuff....

И снова обновить сеть

def update_network(self, states, qvals, actions):
    dict = {self.inputs: states, self.qvals: qvals, self.actions: actions}
    loss, _ = self.sess.run([self.loss, self.train_op], feed_dict=dict)
    return loss

Если я хочу выполнить шаг backprop, мне нужно снова вызвать прямой проход с собранными состояниями.

У меня уже есть все значения, необходимые для расчета потерь. Второй прямой проход - просто указать в сети весовые коэффициенты. Это быстрее, если я вычислю потери из графика и вставлю в оптимизатор. Тогда мне больше не нужно звонить вперед.

Быстрее:

# collect data
for _ in range(NUM_STEPS):
    dist, v = self.model.forward(s)
    pong_actions, a = self.get_action(dist)
    s_, r, d, i = self.envs.step(pong_actions)

    log_probs.append(dist.log_prob(a))
    entropies.append(dist.entropy())
    values.append(v)
    rewards.append(np.expand_dims(np.float32(r), axis=1))
    masks.append(np.expand_dims(np.float32(1 - d), axis=1))

   # do some stuff....

   # Calculate loss
   loss = ...
def update_network(self, loss):
    dict = {self.loss: loss}
    self.sess.run(self.train_op, feed_dict=dict)

Рабочий процесс: состояние подачи> выполнить действие с помощью pred> состояние и действие сбора> состояние и действие подачи для расчета потерь> backprop

Лучший рабочий процесс: состояние подачи> выполнить действие с пред.> расчет потери с пред.> backprop

Проблема со вторым рабочим процессом - ошибка ValueError: No gradients provided for any variable, потому что отсутствует прямой проход

...