Я пытаюсь реализовать алгоритм REINFORCE для пространства непрерывного действия. Я создал игрушку, MDP с одним состоянием, где награда отрицательна от абсолютной разницы между выбранным действием и целевым значением. Оптимальная стратегия - чтобы действие было идентичным целевому значению. Диапазон возможных действий от -inf до inf. Например, если целевое значение равно 5, и агент выполняет действие 2.2, вознаграждение составляет -2.8 = -abs (2.2-5), и эпизод заканчивается. Нейронная сеть с одним линейным нейроном предсказывает среднее гауссовского распределения, которое имеет фиксированное стандартное отклонение, из которого агент выбирает свое действие. Состояние MDP имеет вектор признаков, который равен [1.0].
Когда стандартное отклонение мало, скажем, 0,05, сеть хорошо сходится. Однако, когда стандартное отклонение больше, скажем, 0,5, прогноз сети расходится в направлении, противоположном целевому значению.
Вот код:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tf.enable_eager_execution()
if __name__ == '__main__':
mean_input_layer = tf.keras.Input(shape=(1,))
mean_output_layer = tf.keras.layers.Dense(
1, activation='linear')(mean_input_layer)
mean_model = tf.keras.Model(inputs=mean_input_layer,
outputs=mean_output_layer)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
target_mean = 5
standard_deviation = 0.05 # Works
# standard_deviation = 0.5 # Diverges
for t in range(10000):
with tf.GradientTape() as tape:
mean = mean_model(np.array([[1]]), training=True)
stdev = np.array([[standard_deviation]], dtype=np.float32)
dist = tfp.distributions.Normal(loc=mean, scale=stdev)
samp = dist.sample()
ascent = -(dist.log_prob(samp)*(-tf.math.abs(samp-target_mean)))
mean_grads = tape.gradient(ascent, mean_model.trainable_variables)
optimizer.apply_gradients(
zip(mean_grads, mean_model.trainable_variables))
print('#' + str(t)
+ ' Sample: ' + str(samp.numpy()[0])
+ ' Mean: ' + str(mean.numpy()[0]))
Переменная «восхождение» является отрицательной с намерением выполнить градиентное восхождение, а не спуск.
Вот пример прогона с целевым значением 5 и стандартным отклонением 0,05:
# 0 Образец: [-0,3368746] Среднее значение: [-0,400553]
# 1 Образец: [-0,35207522] Среднее значение: [-0,3980214]
# 2 Образец: [-0,39965397] Среднее: [-0,3947122]
# 3 Образец: [-0,3883655] Среднее значение: [-0,39056838]
...
# 2460 Образец: [5.0231776] Среднее: [4.9940543]
# 2461 Образец: [5.030905] Среднее: [4.99024]
# 2462 Образец: [4.853626] Среднее: [4.9867477]
# 2463 Образец: [4.8647094] Среднее: [4.983813]
# 2464 Образец: [4.9929175] Среднее: [4.982292]
сходится.
Вот пример прогона с целевым значением 5 и стандартным отклонением 0,5:
# 1 Образец: [0.6297094] Среднее: [1.4340767]
# 2 Образец: [0.75481075] Среднее: [1.4310371]
# 3 Образец: [0.9269088] Среднее: [1.4287564]
# 4 Образец: [1.2933123] Среднее: [1.4272974]
...
# 3210 Образец: [-3,4329443] Среднее значение: [-3,322072]
# 3211 Образец: [-3,755511] Среднее: [-3,3225727]
# 3212 Образец: [-3,6817236] Среднее значение: [-3,3237739]
# 3213 Образец: [-3,4897459] Среднее: [-3,324738]
...
# 9996 Образец: [-13.280873] Среднее: [-13.032175]
# 9997 Образец: [-13.879341] Среднее: [-13.032874]
# 9998 Образец: [-12.796365] Среднее: [-13.036192]
# 9999 Образец: [-13.04003] Среднее: [-13.036874]
Он расходится в противоположном направлении от целевого значения.
Что тут происходит? Есть ли ошибка в моей реализации? Спасибо за помощь.
Обновление: Если заменить "samp = dist.sample ()" на выборку из равномерного распределения с центром в среднем с шириной + - 2 стандартных отклонения, случай стандартного отклонения 0,5 сходится , В частности, «samp = tf.constant ([[mean.numpy () [0] - 2 * стандартное отклонение + 4 * стандартное отклонение * np.random.random ()]])». Это явно не реальное решение проблемы, но может дать некоторое представление о том, кто пытается выяснить, что здесь происходит не так. Я все еще ищу решение.
Обновление 2: В качестве еще одной части головоломки, если я переверну знак переменной всплытия, тогда случай стандартного отклонения 0,5 сходится, а случай стандартного отклонения 0,05 расходится.