Несоответствующий размер между batch_size и target_q_values ​​в DQN - PullRequest
0 голосов
/ 27 марта 2019

Может ли кто-нибудь ОБЪЯСНИТЬ мне, почему я получаю эту ошибку, когда пытаюсь приспособить своего агента к окружающей среде?Я видел, что это должно быть повторяющейся ошибкой, и вот объяснение этого.Я использую среду под названием RecoGYM (версия 1), и это мой код:

class RecoProcessor(Processor):
    def process_observation(self, observation):
        look_back = 10
        if observation is None:
            X=np.zeros(look_back)
        else:
            if len(observation)>look_back:
                observation = observation[-look_back]
            observation = np.array(observation)
            if len(observation.shape) == 2:
                X = observation[:,1]
            else:
                X = np.array([observation[1]])
        if len(X)<look_back:
            X = np.append(X,np.zeros(look_back-len(X)))
        return X

    def process_state_batch(self, batch):
        return batch[0]

    def process_reward(self, reward):
        return reward

    def process_demo_data(self, demo_data):
        for step in demo_data:
            step[0] = self.process_observation(step[0])
            step[2] = self.process_reward(step[2])
        return demo_data


lr = 1e-3
window_length = 1
emb_size = 100
look_back = 10

# "Expert" (regular dqn) model architecture

inp = Input(shape=(10,))
emb = Embedding(input_dim=env.action_space.n+1, output_dim = emb_size)(inp) 
rnn = Bidirectional(LSTM(5))(emb)
out = Dense(env.action_space.n, activation='softmax')(rnn)
expert_model = Model(inputs = inp, outputs = out)


# try using different optimizers and different optimizer configs
expert_model.compile(loss='mse',
              optimizer='adam',
              metrics=['acc'])

# memory
memory = PrioritizedMemory(limit=5000,  window_length=window_length)

# policy
policy = BoltzmannQPolicy()

# agent
dqn = DQNAgent(model=expert_model, nb_actions=env.action_space.n, policy=policy, memory=memory, 
               enable_double_dqn=False, enable_dueling_network=False, gamma=.6, 
               target_model_update=1e-2, processor = RecoProcessor())

dqn.compile(Adam(lr), metrics=['mae'])

train = dqn.fit(env, nb_steps=50000, visualize=False, verbose=1, nb_max_episode_steps = None)

И мой вывод выглядит так:

Training for 50000 steps ...
Interval 1 (0 steps performed)
  966/10000 [=>............................] - ETA: 25s - reward: 0.0155

    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    <ipython-input-48-00ed6e6b7fff> in <module>
         54 dqn.compile(Adam(lr), metrics=['mae'])
         55 
    ---> 56 train = dqn.fit(env, nb_steps=50000, visualize=False, verbose=1, nb_max_episode_steps = None)
         57 np.savetxt(fichero_train_history, 
         58            np.array(train.history["episode_reward"]), delimiter=",")

    c:\users\angelo\src\keras-rl\rl\core.py in fit(self, env, nb_steps, action_repetition, callbacks, verbose, visualize, nb_max_start_steps, start_step_policy, log_interval, nb_max_episode_steps)
        192                     # Force a terminal state.
        193                     done = True
    --> 194                 metrics = self.backward(reward, terminal=done)
        195                 episode_reward += reward
        196 

    c:\users\angelo\src\keras-rl\rl\agents\dqn.py in backward(self, reward, terminal)
        330                 # outlined in Mnih (2015). In short: it makes the algorithm more stable.
        331                 target_q_values = self.target_model.predict_on_batch(state1_batch)
    --> 332                 assert target_q_values.shape == (self.batch_size, self.nb_actions)
        333                 q_batch = np.max(target_q_values, axis=1).flatten()
        334             assert q_batch.shape == (self.batch_size,)

    AssertionError:

Большая часть класса RecoProcessor сделан методом проб и ошибок, и я думаю, что ключ находится в этом классе.У меня также есть проблемы с переменными look_back и window_length, потому что я не очень хорошо понимаю разницу между ними.

(ОБНОВЛЕНИЕ 1) Я только что проверил, что переменные assert имеют этиформы:

Target_Q_values: (1, 42)

(self.batch_size, self.nb_actions): (32, 42)

...