Для политики сети в вашем коде, я думаю, вы хотите что-то вроде этого:
loss = -(log(action_probability) * reward)
Где action_probability - это выход вашей сети для действия, выполненного на этом временном шаге.
Например, если ваша сеть выдала 10% -ную вероятность предпринять это действие, но она дала вознаграждение 10, ваша потеря будет: -(log(0.1) * 10)
, что равно 10.
Но, если ваша сеть уже посчитала, что это хороший шаг и выдает 90% -ную вероятность предпринять это действие, у вас будет -log(0.9) * 10)
, что примерно равно 0,45, меньше влияя на сеть.
Стоит отметить, что функция PyTorch log
не является численно стабильной, и вам может быть лучше использовать logsoftmax
на последнем уровне вашей сети.