Можно ли рассчитать потери за пределами графика тензопотока?
Я пишу алгоритм критики актера, и когда я хочу применить градиенты, мне приходится обращаться к временному проходу.
На каждом шаге я вызываю прямой проход, чтобы получить свои логиты и значения:
# collect data
for _ in range(NUM_STEPS):
sample = self.model.get_sample(s)
pong_actions = self.get_action(sample)
s_, r, d, i = self.envs.step(pong_actions)
states.append(s)
actions.append(np.expand_dims(sample, axis=1))
rewards.append(np.expand_dims(np.float32(r), axis=1))
masks.append(np.expand_dims(np.float32(1 - d), axis=1))
# do some stuff....
И снова обновить сеть
def update_network(self, states, qvals, actions):
dict = {self.inputs: states, self.qvals: qvals, self.actions: actions}
loss, _ = self.sess.run([self.loss, self.train_op], feed_dict=dict)
return loss
Если я хочу выполнить шаг backprop, мне нужно снова вызвать прямой проход с собранными состояниями.
У меня уже есть все значения, необходимые для расчета потерь. Второй прямой проход - просто указать в сети весовые коэффициенты.
Это быстрее, если я вычислю потери из графика и вставлю в оптимизатор. Тогда мне больше не нужно звонить вперед.
Быстрее:
# collect data
for _ in range(NUM_STEPS):
dist, v = self.model.forward(s)
pong_actions, a = self.get_action(dist)
s_, r, d, i = self.envs.step(pong_actions)
log_probs.append(dist.log_prob(a))
entropies.append(dist.entropy())
values.append(v)
rewards.append(np.expand_dims(np.float32(r), axis=1))
masks.append(np.expand_dims(np.float32(1 - d), axis=1))
# do some stuff....
# Calculate loss
loss = ...
def update_network(self, loss):
dict = {self.loss: loss}
self.sess.run(self.train_op, feed_dict=dict)
Рабочий процесс:
состояние подачи> выполнить действие с помощью pred> состояние и действие сбора> состояние и действие подачи для расчета потерь> backprop
Лучший рабочий процесс:
состояние подачи> выполнить действие с пред.> расчет потери с пред.> backprop
Проблема со вторым рабочим процессом - ошибка ValueError: No gradients provided for any variable
, потому что отсутствует прямой проход