Вам необходимо рассчитать потери только по Q-значению, для которого выбрано его действие.В вашем примере предположим, что для данной строки в вашей мини-партии действие равно 3
.Затем вы получаете соответствующую цель, y_3
, и тогда потеря составляет (Q(s,3) - y_3)^2
, и в основном вы устанавливаете значение потери других действий на ноль.Вы можете реализовать это, используя gather_nd
в tensorflow
или получив one-hot-encode
версию действий, а затем умножив этот one-hot-encode
вектор на вектор Q-значения.Используя вектор one-hot-encode
, вы можете написать:
action_input = tf.placeholder("float",[None,action_len])
QValue_batch = tf.reduce_sum(tf.multiply(T_Q_value,action_input), reduction_indices = 1)
, в котором action_input = np.eye(nb_classes)[your_action (e.g. 3)]
.За такой же процедурой можно следовать gather_nd
: https://www.tensorflow.org/api_docs/python/tf/gather_nd
Надеюсь, это разрешит вашу путаницу.