Как правильно настроить пресеты тренера - PullRequest
0 голосов
/ 23 сентября 2019

Я пытался понять, как работают механизмы обучения с подкреплением в AWS.Я недавно перешел на платформу COACH после многочисленных проблем с версионированием во время работы с RAY.Я до сих пор не могу понять, как правильно настроить пресеты.Тренировочные циклы иногда продолжаются вечно и не прекращаются, когда я этого ожидаю.Я также не уверен, как определить количество шагов в эпизоде, чтобы модель не продолжала тренироваться.

Награда на изображении здесь доходит до 3,5 миллионов, что мне не нужно,И, как вы можете видеть, очень нестабильно

Я попытался возиться с парой предустановленных настроек, особенно для алгоритма DQN.Я изменил следующие параметры


schedule_params.improve_steps = TrainingSteps(100000) #between 100 and 1000000
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(100) # between 10 and 100
schedule_params.evaluation_steps = EnvironmentEpisodes(10) #between 1 and 10
schedule_params.heatup_steps = EnvironmentSteps(10) #between 10 and 100

Это предварительная установка для DQN:

from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType, EmbedderScheme
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.schedules import ConstantSchedule

from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.gym_environment import GymVectorEnvironment
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.schedules import LinearSchedule
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
from rl_coach.filters.observation.observation_move_axis_filter import ObservationMoveAxisFilter

from rl_coach.architectures.layers import Dense

####################
# Graph Scheduling #
####################

schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(100000)
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(100)
schedule_params.evaluation_steps = EnvironmentEpisodes(10)
schedule_params.heatup_steps = EnvironmentSteps(10)
#########
# Agent #
#########
agent_params = DQNAgentParameters()

# DQN params
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
agent_params.algorithm.discount = 0.99
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)


# NN configuration
agent_params.network_wrappers['main'].learning_rate = 0.00025
agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
# agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense(1)]
agent_params.network_wrappers['main'].batch_size = 64
# agent_params.pre_network_filter.add_observation_filter('observation', 'move_axis',
#     ObservationMoveAxisFilter(0,0))
# agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
#     ObservationNormalizationFilter(name='normalize_observation'))

# ER size
agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)

# E-Greedy schedule
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)

################
#  Environment #
################
env_params = GymVectorEnvironment(level='env:ArrivalSim')
env_params.additional_simulator_parameters = {'price': 30.0 }
# env_params.observation_space_type = ObservationSpaceType
#################
# Visualization #
#################

vis_params = VisualizationParameters()
vis_params.dump_gifs = False

########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = False
preset_validation_params.min_reward_threshold = 8000
preset_validation_params.max_episodes_to_achieve_reward = 250

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params, vis_params=vis_params,
                                    preset_validation_params=preset_validation_params)

проблема главным образом связана с планированием графика.

Я ожидаю, что смогуустановить тренировочный цикл, который имеет фиксированное количество шагов в каждом эпизоде ​​и не продолжается до бесконечности.Я также надеюсь контролировать количество эпизодов.

...