использование сборки на argmax отличается от использования max - PullRequest
0 голосов
/ 27 мая 2020

Я пытаюсь научиться обучать алгоритм двойного DQN на тензорном потоке, и он не работает. Чтобы убедиться, что все в порядке, я хотел что-то протестировать. Я хотел убедиться, что использование tf.gather для argmax точно такое же, как и максимальное значение: допустим, у меня есть сеть с именем target_network:

сначала возьмем max:

next_qvalues_target1 = target_network.get_symbolic_qvalues(next_obs_ph) #returns tensor of qvalues
next_state_values_target1 = tf.reduce_max(next_qvalues_target1, axis=1)

давайте попробуем по-другому - используя argmax и соберите:

next_qvalues_target2 = target_network.get_symbolic_qvalues(next_obs_ph) #returns same tensor of qvalues
chosen_action = tf.argmax(next_qvalues_target2, axis=1)
next_state_values_target2 = tf.gather(next_qvalues_target2, chosen_action)

diff = tf.reduce_sum(next_state_values_target1) - tf.reduce_sum(next_state_values_target2)

next_state_values_target2 и next_state_values_target1 должны быть полностью идентичными. поэтому запуск сеанса должен выводить diff =. но это не так.

Что мне не хватает?

Спасибо.

1 Ответ

1 голос
/ 27 мая 2020

Выяснили, что пошло не так. выбранное действие имеет форму (n, 1), поэтому я подумал, что, используя сборку для переменной, которая (n, 4), я получу результат shape (n, 1). оказывается, это неправда. Мне нужно было превратить selected_action в переменную формы (n, 2) - вместо [action1, action2, action3 ...] мне нужно было, чтобы она была [[1, action1], [2, action2], [3, action3] ....] и используйте gather_nd, чтобы иметь возможность брать определенные c элементы из next_qvalues_target2, а не собирать, потому что gather принимает полные строки.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...