имя 'ModelIntervalCheckpoint' не определено Ошибка в обучении подкреплению kesas - PullRequest
0 голосов
/ 26 марта 2020

Я пытаюсь запустить глубоко усиленный учебный код в Google Co-lab. Когда я пытаюсь запустить этот код, он показывает мне эту ошибку:

def build_callbacks(env_name):
    checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
    log_filename = 'dqn_{}_log.json'.format(env_name)
    callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=5000)]
    callbacks += [FileLogger(log_filename, interval=100)]
    return callbacks

ENV_NAME = 'CartPole-v0'
callbacks = build_callbacks('CartPole-v0')

Ошибка:

NameError                                 Traceback (most recent call last)
<ipython-input-13-0bea44654d9d> in <module>()
      7 
      8 ENV_NAME = 'CartPole-v0'
----> 9 callbacks = build_callbacks('CartPole-v0')

<ipython-input-13-0bea44654d9d> in build_callbacks(env_name)
      2     checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
      3     log_filename = 'dqn_{}_log.json'.format(env_name)
----> 4     callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=5000)]
      5     callbacks += [FileLogger(log_filename, interval=100)]
      6     return callbacks

NameError: name 'ModelIntervalCheckpoint' is not defined

Пожалуйста, помогите мне решить эту проблему. Спасибо. Полный код:

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

model = Sequential()
model.add( Flatten(input_shape=(1,) + (4,)))
model.add(Activation('relu'))
model.add(Dense(2))
model.add(Activation('linear'))
print(model.summary())

policy = EpsGreedyQPolicy()
memory = SequentialMemory(limit=50000, window_length=1)

dqn = DQNAgent(model=model, 
               nb_actions=2, 
               memory=memory, 
               nb_steps_warmup=10,
               target_model_update=1e-2, 
               policy=policy)

def build_callbacks(env_name):
    checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
    log_filename = 'dqn_{}_log.json'.format(env_name)
    callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=5000)]
    callbacks += [FileLogger(log_filename, interval=100)]
    return callbacks

ENV_NAME = 'CartPole-v0'
callbacks = build_callbacks('CartPole-v0')
...