Реализация дуэли DQN на TensorFlow 2.0 - PullRequest
1 голос
/ 28 февраля 2020

Я пытаюсь реализовать свой собственный Dueling DQN с использованием tenorflow 2 на основе https://arxiv.org/pdf/1511.06581.pdf. Я на самом деле тренирую его в среде Атлантиды, но не могу добиться хороших результатов ( Среднее вознаграждение за игру продолжает уменьшаться, в то время как Потери TD увеличиваются). Хотя я полагаю, что я получил логику c из статьи, я не знаю, исходит ли она от прямой реализации сети или выбранных параметров.

РЕДАКТИРОВАТЬ: Использование tf.keras.utils.plot_model дает мне это .

class DQNAgent:
  def __init__(self, state_shape, n_actions, epsilon=0):
    self.state_input = Input(shape=state_shape, name='State')
    self.x = Conv2D(16, (3, 3), strides=2, activation='relu')(self.state_input)
    self.x = Conv2D(32, (3, 3), strides=2, activation='relu')(self.x)
    self.x = Conv2D(64, (3, 3), strides=2, activation='relu')(self.x)
    self.x = Flatten()(self.x)
    self.x = Dense(256, activation='relu')(self.x)

    self.head_v = Dense(256,activation='relu')(self.x)
    self.head_v = Dense(1, activation='linear',name="Value")(self.head_v)
    self.head_v = RepeatVector(n_actions)(self.head_v)
    self.head_v = Flatten()(self.head_v)

    self.head_a = Dense(256,activation='relu')(self.x)
    self.head_a = Dense(n_actions, activation='linear',name='Activation')(self.head_a)

    self.m_head_a = RepeatVector(n_actions)(tf.keras.backend.mean(self.head_a,axis=1,keepdims=True))
    self.m_head_a = Flatten(name='meanActivation')(self.m_head_a)

    self.head_a = Subtract()([self.head_a,self.m_head_a])

    self.head_q = Add(name = "Q-value")([self.head_v,self.head_a])


    self.network = tf.keras.Model(inputs=[self.state_input], outputs=[self.head_q])
    self.weights = self.network.trainable_variables
    self.epsilon = epsilon
    self.optimizer = tf.keras.optimizers.Adam(1e-3)

  def get_qvalues(self, state_t):
    return self.network(state_t)

  def train(self, exp_replay, batch_size=64):
    states, actions, rewards, next_states, is_done = exp_replay.sample(batch_size)
    is_not_done = 1 - is_done

    with tf.GradientTape() as t:
      current_qvalues = agent.get_qvalues(states)
      current_action_qvalues = tf.reduce_sum(tf.one_hot(actions, n_actions) * current_qvalues, axis=-1)
      next_qvalues_target = target_network.get_qvalues(next_states)
      next_state_values_target = tf.reduce_max(next_qvalues_target, axis=-1)
      reference_qvalues = rewards + gamma*next_state_values_target*is_not_done
      td_loss = (current_action_qvalues - reference_qvalues)**2
      td_loss = tf.math.reduce_mean(td_loss)

    var_list = agent.weights
    grads = t.gradient(td_loss,var_list)
    self.optimizer.apply_gradients(zip(grads, var_list))
    return td_loss


  def sample_actions(self, qvalues):
    batch_size, n_actions = qvalues.shape
    random_actions = np.random.choice(n_actions, size=batch_size)
    best_actions = tf.math.argmax(qvalues, axis=-1)
    should_explore = np.random.choice([0, 1], batch_size, p = [1-self.epsilon, self.epsilon])
    return np.where(should_explore, random_actions, best_actions)


def load_weights_into_target_network(agent, target_network):
  for t, e in zip(target_network.network.trainable_variables, agent.network.trainable_variables):
    t.assign(e)

env = make_env() # Apply frame buffer on "AtlantisDeterministic-V4" env
env.reset()
n_actions = env.action_space.n
state_dim = env.observation_space.shape

agent = DQNAgent(state_dim, n_actions, epsilon=0.5)    
target_network = DQNAgent(state_dim, n_actions)

exp_replay = ReplayBuffer(10**5) # Random experience replay buffer
play_and_record(agent, env, exp_replay, n_steps=10000) # Plays exactly n_steps and records each transition in the ReplayBuffer
gamma = 0.99

for i in trange(10**5):
  play_and_record(agent, env, exp_replay, 10)

  td_loss = agent.train(exp_replay, 64)

  # adjust agent parameters
  if i % 500 == 0:
    load_weights_into_target_network(agent, target_network)
    agent.epsilon = max(agent.epsilon * 0.99, 0.01)

1 Ответ

0 голосов
/ 03 марта 2020

Проблема связана с тем, что целевая сеть не обновляется корректно, из-за ошибки программирования c. Тем не менее, он работал очень хорошо с данной модификацией буфера. Спасибо за вашу помощь.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...