Я недавно узнал о Q-Learning на примере среды Gym "CartPole-v1".
Функция предсказания указанной модели всегда возвращает вектор, который выглядит как [[ 0.31341377 -0.03776223]]
. Я создал свою собственную маленькую игру, в которой Ai должен двигаться влево или вправо с выходными значениями 0 и 1. Я просто показываю список [0, 0, 1, 0, 0]
в сеть, если он выдает 0, он уходит влево, если он выдает 1, он идет вправо. Достигните левого 0 и вы выиграете, правого 0 и проиграете. Действительно легко. Однако, когда я печатаю свой выходной вектор, я всегда получаю что-то вроде этого:
[[0.01347399 0.04450664]
[0.01347399 0.04450664]
[0.01347399 0.04450664]
[0.1216775 0.38299465]
[0.01347399 0.04450664]]
Это портит функцию обучения, потому что np.argmax()
затем возвращает что-то вроде или 5, и сеть не может справиться с этим, учитывая тот факт, что для начала есть только 2 действия.
Это инициация моей модели:
def __init__(self, state_shape, num_actions, lr):
super(DQN, self).__init__()
self.state_shape = state_shape # (1,)
self.num_actions = num_actions # 2
self.lr = lr # 1e-3
input_state = Input(shape=state_shape)
x = Dense(20)(input_state)
x = Activation('relu')(x)
x = Dense(20)(x)
x = Activation('relu')(x)
output_pred = Dense(self.num_actions)(x)
self.model = Model(inputs=input_state, outputs=output_pred)
self.model.compile(loss="mse", optimizer=Adam(lr=self.lr))
Полный код доступен по адресу https://www.mediafire.com/file/rq7ogjxpr990e51/dqn.py/file.
Как мне обрезать выходной вектор? Или как мне изменить свои входные данные, чтобы получить полезный вывод?
Редактировать:
Я немного поэкспериментировал, и увеличение num_actions от 2 до, например, 4 действительно увеличивает вектор по горизонтали, поэтому он выглядит следующим образом:
[[ 0.00109814 0.01464381 -0.00270887 -0.00422738]
[ 0.00109814 0.01464381 -0.00270887 -0.00422738]
[-0.01450843 0.10628925 -0.06114068 -0.10908635]
[ 0.00109814 0.01464381 -0.00270887 -0.00422738]
[ 0.00109814 0.01464381 -0.00270887 -0.00422738]]
Это означает, что num_actions как 2 не проблема, скорее, он выдает 5 строк вместо 1.