Может ли DQNAgent с переопределением PrioritizedMemory? - PullRequest
0 голосов
/ 28 марта 2019

Я использую DQNAgent и PrioritizedMemory, чтобы тренироваться против среды, и за вознаграждение это может быть слишком подходящим, но может ли это действительно случиться, когда среда только показывает новые состояния или это ошибка памяти.

Мой код здесь ниже, просто для того, чтобы вы могли проверить, хотите ли вы, и сказать мне, что не так, если вы видите что-то:

lr = 1e-3
emb_size = 100
look_back = 10

# "Expert" (regular dqn) model architecture

inp = Input(shape=(10,))
emb = Embedding(input_dim=env.action_space.n+1, output_dim = emb_size)(inp) 
rnn = Bidirectional(LSTM(5))(emb)
out = Dense(env.action_space.n, activation='softmax')(rnn)
expert_model = Model(inputs = inp, outputs = out)
expert_model.compile(loss='mse',
              optimizer='adam',
              metrics=['acc'])

# memory
memory = PrioritizedMemory(limit=5000,  window_length=window_length)

# policy
policy = BoltzmannQPolicy()

# agent
dqn = DQNAgent(model=expert_model, nb_actions=env.action_space.n, policy=policy, memory=memory, 
               enable_double_dqn=False, enable_dueling_network=False, gamma=.9, batch_size = 1, #Doesnt work if I change the batch size
               target_model_update=1e-2, processor = RecoProcessor())

dqn.compile(Adam(lr), metrics=['mae'])

train = dqn.fit(env, log_interval = 1000,nb_steps=50000, visualize=False, verbose=1, nb_max_episode_steps = None)

И мой RecoProcessor класс:

class RecoProcessor(Processor):
    def process_observation(self, observation):
        look_back = 10
        if observation is None:
            X=np.zeros(look_back)
        else:
            if len(observation)>look_back:
                observation = observation[-look_back]
            observation = np.array(observation)
            if len(observation.shape) == 2:
                X = observation[:,1]
            else:
                X = np.array([observation[1]])
        if len(X)<look_back:
            X = np.append(X,np.zeros(look_back-len(X)))
        return X

    def process_state_batch(self, batch):
#         CHECK. i THINK SOMETHING IS WRONG HERE
#         print(batch)
#         print(batch[0])
        return batch[0]

    def process_reward(self, reward):
        return reward

    def process_demo_data(self, demo_data):
        for step in demo_data:
            step[0] = self.process_observation(step[0])
            step[2] = self.process_reward(step[2])
        return demo_data
...