Я реализовал некоторые рекуррентные модели с TensorFlow, используя tf.scan, и в настоящее время я пытаюсь получить градиент выходных потерь по состоянию на каждом временном шаге для исследовательских целей.
Ключевой процесс может бытьсводится к вычислению градиента по промежуточным значениям внутри тела while_loop:
import tensorflow as tf
tf.InteractiveSession()
sequence = tf.constant([2, 3, 4, 5])
def body(t, previous_s):
intermediate_s = previous_s * sequence[t]
return t + 1, intermediate_s
t0, s0 = 0, 1
_, s4 = tf.while_loop(lambda t, _: t < sequence.shape[0], body, (t0, s0))
print('s4 =', s4.eval())
Выходное значение равно
s4 = 120
, а промежуточные результаты, вычисленные во время цикла while, равны
s1 = 2, s2 = 6, s3 = 24
Итак, есть ли способ доступа к этим тензорам, чтобы я мог рассчитать соответствующий градиент, используя tf.gradient, скажем, для dS4/dS1
:
ds4_ds1 = tf.gradient(s4, s1)[0]