ValueError: Ошибка формы ввода с агентом Keras DQN - PullRequest
1 голос
/ 02 апреля 2020

У меня небольшая ошибка при использовании агента DQN RL от Keras. Я создал свою собственную среду гимнастики OpenAI, которая выводит массив numpy размера 1 для наблюдения. Тем не менее, когда я вызываю функцию подгонки в моей среде, я получаю ошибку ValueError: Error when checking input: expected flatten_3_input to have shape (1, 1) but got array with shape (1, 4). Я использовал тот же код (изменяя только форму ввода на (1,4)) в среде CartPole без ошибок, поэтому я очень смущен тем, в чем здесь проблема. На каждом этапе среда моего тренажерного зала возвращает кортеж в форме (numpy array, float, bool, dict), того же формата, что и CartPole. Мои политики и целевые сети имеют вид:

def agent(shape, actions):
    model = Sequential()
    model.add(Flatten(input_shape = (1, shape)))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(actions, activation='linear'))
    return model

Следующее выдает ошибку в функции подбора:

model = agent(1, len(env.action_space))
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannGumbelQPolicy()

dqn = DQNAgent(model=model, policy=policy, nb_actions=len(env.action_space), memory=memory)
dqn.compile('adam', metrics = ['mae'])
dqn.fit(env, nb_steps = 50000, visualize = False, verbose = 1)

Я читаю ответы на похожие проблемы на модели Keras : Ошибка измерения формы ввода для агента RL , но я не могу решить эту проблему. Любые предложения здесь? Спасибо!

...