Градиент политики в Keras предсказывает только одно действие — PullRequest
2 голоса
/ 29 марта 2019

У меня проблемы с алгоритмом REINFORCE в Keras на играх Atari. Примерно через 30 эпизодов сеть сходится к одному действию. При этом тот же алгоритм работает на CartPole-v1: примерно через 350 эпизодов он сходится к средней награде 495,0. Почему проблемы возникают именно с играми Atari? Я не понимаю, что делаю не так в функции потерь. Вот мой код:

Примеры политики

episode 0:
p = [0.15498623 0.16416906 0.15513565 0.18847148 0.16070205 0.17653547]
....

episode 30:
p = [0. 0. 0. 0. 1. 0.]
....

episode 40:
p = [0. 0. 0. 0. 1. 0.]
....

Сеть

import os

import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.backend.tensorflow_backend import set_session
from keras.layers import *
from keras.models import Model
from keras.optimizers import Adam

# Let TensorFlow allocate GPU memory on demand (allow_growth) and make this
# session the one Keras uses for every model built afterwards.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)
set_session(sess)

class PGAtariNetwork:
    """Convolutional softmax policy trained with the REINFORCE loss.

    The Keras model has two inputs: the observation batch and a per-sample
    return that only scales the loss. At prediction time the return input is
    fed a dummy value.
    """

    def __init__(self, state_space, action_space, lr, entropy_beta=0.0):
        """Build and compile the policy network.

        Args:
            state_space: observation shape, passed to ``Input(shape=...)``.
            action_space: number of discrete actions (width of the softmax).
            lr: Adam learning rate.
            entropy_beta: weight of an optional entropy bonus. ``0.0`` (the
                default) gives the plain REINFORCE loss. A small positive
                value (e.g. 0.01) penalises near-deterministic policies and so
                discourages an early collapse onto a single action.
        """
        # Named ``inputs`` so the builtin ``input`` is not shadowed; the Keras
        # layer name stays 'inputs'.
        inputs = Input(shape=state_space, name='inputs')
        rewards = Input(shape=(1,), name='rewards')

        conv1 = Conv2D(filters=64, kernel_size=(8, 8), strides=(4, 4),
                       activation='relu', name='conv1')(inputs)
        conv2 = Conv2D(filters=128, kernel_size=(4, 4), strides=(2, 2),
                       activation='relu', name='conv2')(conv1)
        conv3 = Conv2D(filters=256, kernel_size=(4, 4), strides=(2, 2),
                       activation='relu', name='conv3')(conv2)

        flatten = Flatten()(conv3)

        fc1 = Dense(units=512, activation='relu', name='fc1')(flatten)
        fc2 = Dense(units=256, activation='relu', name='fc2')(fc1)

        p = Dense(units=action_space, activation='softmax')(fc2)

        def policy_loss(r):
            """Return a Keras loss closed over the symbolic return tensor ``r``."""
            def pol_loss(y_true, y_pred):
                # y_true is the one-hot taken action, so the sum is pi(a|s);
                # the epsilon keeps the log finite when pi(a|s) rounds to 0.
                log_prob = K.log(K.sum(y_pred * y_true, axis=1, keepdims=True) + 1e-10)
                # The return only weights the log-probability; no gradient
                # should flow into it.
                loss = -log_prob * K.stop_gradient(r)
                if entropy_beta:
                    entropy = -K.sum(y_pred * K.log(y_pred + 1e-10),
                                     axis=1, keepdims=True)
                    loss = loss - entropy_beta * entropy
                return loss

            return pol_loss

        self.model = Model(inputs=[inputs, rewards], outputs=p)
        self.model.compile(loss=policy_loss(rewards), optimizer=Adam(lr=lr))
        self.model.summary()

    def predict(self, s):
        """Return the action probabilities for a single state ``s`` (batch of one)."""
        s = s[np.newaxis, :]
        # The rewards input only scales the loss; any value works for inference.
        return self.model.predict([s, np.array([1])])

    def update_model(self, target):
        """Copy all weights from ``target`` into this network.

        ``target`` may be another ``PGAtariNetwork`` (its ``.model`` is used)
        or a Keras model exposing ``get_weights()`` directly.
        """
        source = getattr(target, 'model', target)
        self.model.set_weights(source.get_weights())

    def train(self, s, a, r):
        """Run one gradient step on a batch of states, one-hot actions and returns."""
        # The 'rewards' input is declared with shape (1,), so feed returns as (T, 1).
        self.model.train_on_batch([s, np.reshape(r, (-1, 1))], a)

    def save_weights(self, path):
        """Save the model weights to ``path``."""
        self.model.save_weights(path)

    def load_weights(self, path):
        """Load weights from ``path`` if that file exists; otherwise do nothing."""
        if os.path.isfile(path):
            self.model.load_weights(path)

Обучение

import gym
import numpy as np
import matplotlib.pyplot as plt

from atari_wrapper import *
from PG_Arari_Network import *


class Agent:
    """REINFORCE agent: plays whole episodes and updates the policy after each."""

    def __init__(self, env):
        self.env = env
        self.state_space = env.observation_space.shape
        self.action_space = env.action_space.n

        # Hyperparameter
        self.gamma = 0.97
        # NOTE(review): 1e-3 is high for policy gradients with Adam on an
        # Atari-sized conv net; a premature collapse onto one action is a
        # typical symptom. Consider ~1e-4 if that happens.
        self.lr = 0.001

        # Environment
        self.mean_stop_train = 19.5   # stop once the last-10-episode mean reaches this
        self.save_model_episode = 10  # NOTE(review): not used anywhere in this class
        self.show_policy = False      # when True, get_action prints the probabilities

        # Network
        self.model_path = 'Pong.h5'
        self.model = PGAtariNetwork(self.state_space, self.action_space, self.lr)
        self.model.load_weights(self.model_path)

        # Lists
        self.rewards = []  # total (undiscounted) reward of each finished episode

    def train(self, episodes):
        """Play up to ``episodes`` episodes, updating the policy after each one.

        Saves the weights to ``self.model_path`` and returns early once the
        mean reward over the last 10 episodes reaches ``self.mean_stop_train``.
        """
        for episode in range(episodes):
            time_steps = 0
            states, actions, episode_rewards = [], [], []
            episode_reward = 0
            s = np.array(self.env.reset())

            while True:
                time_steps += 1
                # Print the action distribution on every step of every 10th episode.
                self.show_policy = episode % 10 == 0
                a = self.get_action(s)
                self.show_policy = False
                s_, r, d, _ = self.env.step(a)
                s_ = np.array(s_)
                episode_reward += r

                action = np.zeros(self.action_space)
                action[a] = 1
                actions.append(action)
                states.append(s)
                episode_rewards.append(r)

                if d:
                    discounted_episode_rewards = self.discount_rewards(episode_rewards)
                    self.update_policy(states, actions, discounted_episode_rewards)

                    self.rewards.append(episode_reward)
                    mean_rewards = np.mean(self.rewards[-10:])
                    print('Episode: {}\tReward: {}\tMean: {}\tSteps: {}'.format(
                        episode, episode_reward, mean_rewards, time_steps))
                    if mean_rewards >= self.mean_stop_train:
                        self.model.save_weights(self.model_path)
                        return
                    break
                s = s_

    def get_action(self, s):
        """Sample an action index from the current policy for state ``s``."""
        p = self.model.predict(s)[0]
        if self.show_policy:
            print(p)
        # The softmax output may be float32; renormalise in float64 so that
        # np.random.choice's sum-to-one check cannot trip on rounding error.
        p = np.asarray(p, dtype=np.float64)
        p /= p.sum()
        return np.random.choice(self.action_space, p=p)

    def discount_rewards(self, episode_rewards):
        """Return the standardised discounted return for every step of an episode.

        G_t = r_t + gamma * G_{t+1}; the returns are then shifted and scaled
        to zero mean and unit standard deviation over the episode.

        Args:
            episode_rewards: per-step rewards of one episode, in order.

        Returns:
            1-D float64 array with one entry per step (empty for no steps).
        """
        n = len(episode_rewards)
        # Always use a float buffer. ``np.zeros_like`` inherits the rewards'
        # dtype, so integer rewards (a reward-clipping wrapper may return e.g.
        # the sign of an integer score) truncated every stored return toward
        # zero -- 0.97 became 0 -- leaving almost no learning signal.
        discounted = np.zeros(n, dtype=np.float64)
        if n == 0:
            return discounted

        cumulative = 0.0
        for i in reversed(range(n)):
            cumulative = cumulative * self.gamma + episode_rewards[i]
            discounted[i] = cumulative

        mean = np.mean(discounted)
        std = np.std(discounted)
        # Constant returns (e.g. an all-zero episode) give std ~ 0; dividing by
        # it would produce NaN/huge values and poison the network weights.
        if std < 1e-8:
            std = 1.0
        return (discounted - mean) / std

    def update_policy(self, states, actions, rewards):
        """Train the network on one episode of states, one-hot actions and returns."""
        s = np.array(states)
        a = np.array(actions)  # list of equal-length one-hot vectors -> 2-D array
        self.model.train(s, a, rewards)


if __name__ == '__main__':
    # Pong wrapped with wrap_deepmind (life-loss episodes, reward clipping,
    # frame stacking and scaling enabled), then train and plot the reward of
    # every finished episode.
    env = wrap_deepmind(
        make_atari('PongNoFrameskip-v0'),
        episode_life=True,
        clip_rewards=True,
        frame_stack=True,
        scale=True,
    )
    agent = Agent(env)
    agent.train(30000)

    history = agent.rewards
    plt.plot(range(len(history)), history, color='blue')
    plt.title('Atari Policy Gradient')
    plt.show()
...