Стабильные базовые показатели сохраняют модель PPO и переподготовку - PullRequest
0 голосов
/ 02 февраля 2020

Здравствуйте, я использую стабильный базовый пакет (https://stable-baselines.readthedocs.io/), в частности, я использую PPO2, и я не уверен, как правильно сохранить свою модель ... Я тренировался в течение 6 виртуальных дней и я получил средний доход около 300, потом решил, что этого недостаточно для меня, поэтому я тренировался с моделью еще 6 дней. Но когда я посмотрел статистику тренировок, второй результат тренировок в каждом эпизоде ​​начался около 30. Это говорит о том, что он не сохранил все параметры.

вот как я сохраняю использование пакета:

def make_env_init(env_id, rank, seed=0):
    """
    Utility function for multiprocessed env.

    :param env_id: (str) the environment ID
    :param seed: (int) the inital seed for RNG
    :param rank: (int) index of the subprocess
    """

    def env_init():
        # Important: use a different seed for each environment
        env = gym.make(env_id, connection=blt.DIRECT)
        env.seed(seed + rank)
        return env

    set_global_seeds(seed)
    return env_init



envs = VecNormalize(SubprocVecEnv([make_env_init(f'envs:{env_name}', i) for i in range(processes)]), norm_reward=False)

if os.path.exists(folder / 'model_dump.zip'):
    model = PPO2.load(folder / 'model_dump.zip', envs, **ppo_kwards)
else:
    model = PPO2(MlpPolicy, envs, **ppo_kwards)

model.learn(total_timesteps=total_timesteps, callback=callback)
model.save(folder / 'model_dump.zip')

1 Ответ

0 голосов
/ 06 апреля 2020

Способ, которым вы сохранили модель, верен. Обучение не является монотонным процессом: оно может также показать гораздо худшие результаты после дальнейшего обучения.

Что вы можете сделать, прежде всего, это записать журналы прогресса:

model = PPO2(MlpPolicy, envs, tensorboard_log="./logs/progress_tensorboard/")

Чтобы увидеть журнал, запустите в терминале:

tensorboard --port 6004 --logdir ./logs/progress_tensorboard/

он даст вам ссылку на доску, которую вы затем сможете открыть в браузере (например, http://pc0259: 6004 / )

Во-вторых, вы можете делать снимки модели каждые X шагов:

from stable_baselines.common.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(save_freq=1e4, save_path='./model_checkpoints/')
model.learn(total_timesteps=total_timesteps, callback=[callback, checkpoint_callback])

Комбинируя ее с журналом, вы можете подобрать модель, которая показала наилучшие результаты!

...