Я столкнулся с большой проблемой с реализацией в tenorflow 2 агента DDPG. Хотя обновление критической c сети ясное и простое (просто сделайте градиентный спуск по потере), обновление актера немного сложнее.
Это моя реализация "actor_update" функция:
def actor_train(self, minibatch):
s_batch, _, _, _, _ = minibatch
with tf.GradientTape() as tape1:
with tf.GradientTape() as tape2:
mu = self.actor_network(s_batch)
q = self.critic_network([s_batch, mu])
mu_grad = tape1.gradient(mu, self.actor_network.trainable_weights)
q_grad = tape2.gradient(q, self.actor_network.trainable_weights)
x = np.array(q_grad)*np.array(mu_grad)
x /= -len(minibatch)
self.actor_optimizer.apply_gradients(zip(x, self.actor_network.trainable_weights))
Как указано в статье, оптимизация представляет собой произведение двух градиентов: один - градиент функции Q по отношению к действиям, а другой - градиент функции субъекта по весам .
Начиная все сети с весами, взятыми с помощью равномерного распределения между -1e-3 и 1e-3, субъект, похоже, не обновляет его веса. Вместо этого, отображение результата критического анализа c (с использованием MountainCarContinous в качестве тестовой среды) показывает небольшую зависимость от данных.
Это код критического значения c для полноты:
def critic_train(self, minibatch):
s_batch, a_batch, r_batch, s_1_batch, t_batch = minibatch
mu_prime = np.array(self.actor_target_network(s_1_batch))
q_prime = self.critic_target_network([s_1_batch, mu_prime])
ys = r_batch + self.GAMMA * (1 - t_batch) * q_prime
with tf.GradientTape() as tape:
predicted_qs = self.critic_network([s_batch, a_batch])
loss = tf.keras.losses.MSE(ys, predicted_qs)
dloss = tape.gradient(loss, self.critic_network.trainable_weights)
self.critic_optimizer.apply_gradients(zip(dloss, self.critic_network.trainable_weights))
В качестве дополнения, актер, кажется, насыщается после победного эпизода. (Означает, что он застревает на +1 или -1 для каждого входа).
В чем проблема? Функция обновления работает правильно? Или это только проблема настройки гиперпараметров?
Это репо, кто-то хочет иметь лучшее представление о проблеме: Github repo