Конвертировать Pytorch DDPG в Tensorflow - PullRequest
0 голосов
/ 12 ноября 2018

Я нашел эту реализацию DDPG и хотел бы преобразовать ее в Tensorflow: https://github.com/higgsfield/RL-Adventure-2/blob/master/5.ddpg.ipynb

Я использую Eager Execution и у меня возникли некоторые проблемы с реализацией функции обновления ddpg.Возможно, я допустил некоторые ошибки, но не могу их найти

--- PYTORCH ---

def ddpg_update(batch_size, 
           gamma = 0.99,
           min_value=-np.inf,
           max_value=np.inf,
           soft_tau=1e-2):

    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    state      = torch.FloatTensor(state).to(device)
    next_state = torch.FloatTensor(next_state).to(device)
    action     = torch.FloatTensor(action).to(device)
    reward     = torch.FloatTensor(reward).unsqueeze(1).to(device)
    done       = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device)

    policy_loss = value_net(state, policy_net(state))
    policy_loss = -policy_loss.mean()

    next_action    = target_policy_net(next_state)
    target_value   = target_value_net(next_state, next_action.detach())
    expected_value = reward + (1.0 - done) * gamma * target_value
    expected_value = torch.clamp(expected_value, min_value, max_value)

    value = value_net(state, action)
    value_loss = value_criterion(value, expected_value.detach())


    policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_optimizer.step()

    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()

    for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
            )

    for target_param, param in zip(target_policy_net.parameters(), policy_net.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
            )

--- TENSORFLOW ---

def ddpg_update(batch_size, 
           gamma = 0.99,
           min_value=-np.inf,
           max_value=np.inf,
           soft_tau=1e-2):

    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    state = np.reshape(state, (batch_size, state_dim))
    next_state = np.reshape(next_state, (batch_size, state_dim))
    action = np.reshape(action, (batch_size, action_dim))
    done = np.reshape(done, (batch_size, 1))


    t_state = tf.convert_to_tensor(state, dtype=tf.float32)
    t_action = tf.convert_to_tensor(action, dtype=tf.float32)
    t_reward = tf.convert_to_tensor(reward, dtype=tf.float32)
    t_next_state = tf.convert_to_tensor(next_state, dtype=tf.float32)
    t_done = tf.convert_to_tensor(done, dtype=tf.float32)

    with tf.GradientTape(persistent=True) as tape:
        policy_loss = tf.reduce_mean(value_net.predict(t_state, policy_net.predict(t_state)))

        t_next_action = target_policy_net.predict(t_next_state)
        t_target_value = target_value_net.predict(t_next_state, t_next_action)
        expected_value = t_reward + (1.0 - t_done) * gamma * t_target_value
        expected_value = tf.clip_by_value(expected_value, tf.constant(min_value), tf.constant(max_value))

        value = value_net.predict(t_state, t_action)
        value_loss = value_criterion(value, expected_value)

    policy_grads = tape.gradient(policy_loss, policy_net.variables)
    value_grads = tape.gradient(value_loss, value_net.variables)
    policy_optimizer.apply_gradients(zip(policy_grads, policy_net.variables))
    value_optimizer.apply_gradients(zip(value_grads, value_net.variables))

    #update value target
    for x in range(len(value_net.variables)):
        #targe = (1-tau)*target + tau*source
        #target = target - tau*(target-source)
        target_value_net.variables[x].assign_sub(soft_tau * (target_value_net.variables[x] - value_net.variables[x]))


    #update policy target
    for x in range(len(policy_net.variables)):
        target_policy_net.variables[x].assign_sub(soft_tau * (target_policy_net.variables[x] - policy_net.variables[x]))
...