Я пытаюсь реализовать DDPG для маятниковой среды OpenAI, используя Tensorflow и Keras. После нескольких итераций изучения градиенты вывода модели актора по отношению к параметрам актора, кажется, всегда стремятся к нулю, в результате чего модель перестает изучать что-либо. Мой код основан на нескольких общедоступных проектах, в которых используется аналогичная настройка и обучение сети, и все, кажется, работают нормально. Основная функция для обучения в приведенном ниже коде - это функция критика train (), которая вызывает функцию поезда актера внутри.
class CriticNetwork:
def __init__(self, sess, state_size, action_size, actor, lr=0.001):
self.sess = sess
self.learning_rate = lr
self.tau = 0.001
self.state_size = state_size
self.action_size = action_size
self.model, self.state, self.action = self.create_network()
self.target_model, self.target_state, self.target_action = self.create_network()
self.action_grads = K.gradients(self.model.output, self.action)
self.gamma = 0.99
self.actor = actor
self.sess.run(tf.initialize_all_variables())
def create_network(self):
S = Input(shape=self.state_size)
A = Input(shape=self.action_size)
s1 = Dense(400, activation='relu')(S)
h1 = Concatenate()([s1,A])
h2 = Dense(300, activation='relu')(h1)
output = Dense(1, activation='linear')(h2)
model = Model(inputs=[S,A], outputs=output)
adam = Adam(lr=self.learning_rate)
model.compile(loss='mse', optimizer=adam)
return model, S, A
def gradients(self, states, actions):
return self.sess.run(self.action_grads, feed_dict={
self.state : states,
self.action : actions
})[0]
def update_target_model(self):
critic_weights = np.array(self.model.get_weights())
critic_target_weights = np.array(self.target_model.get_weights())
critic_target_weights = self.tau * critic_weights + (1 - self.tau) * critic_target_weights
self.target_model.set_weights(critic_target_weights)
def train(self, batch):
states = batch[0]
actions = batch[1]
next_states = batch[2]
rewards = batch[3]
y = rewards + self.gamma*self.target_predict(next_states, self.actor.target_act(next_states))
self.model.fit([states, actions], y, verbose=0)
a_for_grad = self.actor.model.predict(states)
grads = self.gradients(states, a_for_grad)
self.actor.train(states, grads)
self.update_target_model()
self.actor.update_target_model()
def predict(self, states, actions):
return self.model.predict([states, actions])
def target_predict(self, states, actions):
return self.target_model.predict([states, actions])
class ActorNetwork:
def __init__(self, sess, state_size, action_size, lr=0.001):
self.sess = sess
self.learning_rate = lr
self.tau = 0.001
self.state_size = state_size
self.action_size = action_size
self.model, self.state = self.create_network()
self.target_model, self.target_states = self.create_network()
self.weights = self.model.trainable_weights
self.action_gradient = tf.placeholder(tf.float32, [None, action_size])
self.params_grad = tf.gradients(self.model.output, self.weights, -self.action_gradient)
grads = zip(self.params_grad, self.weights)
self.optimize = Adam(lr).apply_gradients(grads)
self.sess.run(tf.initialize_all_variables())
self.gamma = 0.99
def create_network(self):
S = Input(shape=self.state_size)
h1 = Dense(400, activation='relu')(S)
h2 = Dense(300, activation='relu')(h1)
output = Dense(1, activation='tanh')(h2)
model = Model(inputs=S, outputs=output)
return model, S
def update_target_model(self):
actor_weights = np.array(self.model.get_weights())
actor_target_weights = np.array(self.target_model.get_weights())
actor_target_weights = self.tau * actor_weights + (1 - self.tau) * actor_target_weights
self.target_model.set_weights(actor_target_weights)
def train(self, states, action_grads):
self.sess.run(self.optimize, feed_dict={
self.state: states,
self.action_gradient: action_grads
})
def act(self, states):
return self.model.predict(states)
def target_act(self, states):
return self.target_model.predict(states)
У кого-нибудь есть идеи, почему градиенты актера всегда становятся нулями? Я должен также упомянуть, что сам вывод не равен нулю, поэтому я не думаю, что все значения ReLU идут ниже нуля.