Я пытаюсь использовать keras-rl с OpenAIGym, и когда я устанавливаю среду как игру Cartpole, она работает. Дело в том, что, когда я пытаюсь установить среду как RecoGYM (представлена на в этой статье ), я не могу заставить ее работать. Вот мой код:
lr = 1e-3
window_length = 2
emb_size = 100
look_back = 10
# "Expert" (regular dqn) model architecture
expert_model = Sequential()
expert_model.add(Embedding(env.action_space.n, emb_size, input_length=look_back, mask_zero=True))
expert_model.add(LSTM(100))
expert_model.add(Dense(env.action_space.n, activation='softmax'))
# memory
memory = PrioritizedMemory(limit=5000, window_length=window_length)
# policy
policy = BoltzmannQPolicy()
# agent
dqn = DQNAgent(model=expert_model, nb_actions=env.action_space.n, policy=policy, memory=memory,
enable_double_dqn=False, enable_dueling_network=False, #gamma=.3,
target_model_update=1e-2, nb_steps_warmup=100)
dqn.compile(Adam(lr), metrics=['mae'])
env.reset()
train = dqn.fit(env, nb_steps=50000, visualize=False, verbose=1, nb_max_episode_steps = None)
Но только в этой среде я получаю эту ошибку:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-5-fe4e63a38263> in <module>
27
28 env.reset()
---> 29 train = dqn.fit(env, nb_steps=50000, visualize=False, verbose=1, nb_max_episode_steps = None)
30 np.savetxt(fichero_train_history,
31 np.array(train.history["episode_reward"]), delimiter=",")
c:\users\myuser\src\keras-rl\rl\core.py in fit(self, env, nb_steps, action_repetition, callbacks, verbose, visualize, nb_max_start_steps, start_step_policy, log_interval, nb_max_episode_steps)
133 if self.processor is not None:
134 observation = self.processor.process_observation(observation)
--> 135 assert observation is not None
136
137 # Perform random starts at beginning of episode and do not record them into the experience.
AssertionError:
И среда определяется следующим образом:
from reco_gym import env_1_args
# you can overwrite environment arguments here:
env_1_args['random_seed'] = 1234
env_1_args['num_products'] = 42
ENV_NAME = "reco-gym-v1"
env = gym.make(ENV_NAME)
env.init_gym(env_1_args)
np.random.seed(123)
env.reset()