Я пытаюсь запустить глубоко усиленный учебный код в Google Co-lab. Когда я пытаюсь запустить этот код, он показывает мне эту ошибку:
def build_callbacks(env_name):
checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
log_filename = 'dqn_{}_log.json'.format(env_name)
callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=5000)]
callbacks += [FileLogger(log_filename, interval=100)]
return callbacks
ENV_NAME = 'CartPole-v0'
callbacks = build_callbacks('CartPole-v0')
Ошибка:
NameError Traceback (most recent call last)
<ipython-input-13-0bea44654d9d> in <module>()
7
8 ENV_NAME = 'CartPole-v0'
----> 9 callbacks = build_callbacks('CartPole-v0')
<ipython-input-13-0bea44654d9d> in build_callbacks(env_name)
2 checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
3 log_filename = 'dqn_{}_log.json'.format(env_name)
----> 4 callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=5000)]
5 callbacks += [FileLogger(log_filename, interval=100)]
6 return callbacks
NameError: name 'ModelIntervalCheckpoint' is not defined
Пожалуйста, помогите мне решить эту проблему. Спасибо. Полный код:
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory
model = Sequential()
model.add( Flatten(input_shape=(1,) + (4,)))
model.add(Activation('relu'))
model.add(Dense(2))
model.add(Activation('linear'))
print(model.summary())
policy = EpsGreedyQPolicy()
memory = SequentialMemory(limit=50000, window_length=1)
dqn = DQNAgent(model=model,
nb_actions=2,
memory=memory,
nb_steps_warmup=10,
target_model_update=1e-2,
policy=policy)
def build_callbacks(env_name):
checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
log_filename = 'dqn_{}_log.json'.format(env_name)
callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=5000)]
callbacks += [FileLogger(log_filename, interval=100)]
return callbacks
ENV_NAME = 'CartPole-v0'
callbacks = build_callbacks('CartPole-v0')