tf.GradientTape () не работает с нарезанными выводами - PullRequest
2 голосов
/ 03 августа 2020

Вот фрагмент кода, который я пытался запустить:

import tensorflow as tf

a = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
b = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)

with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
    tape1.watch(a)
    tape2.watch(a)
    
    c = a * b

grad1 = tape1.gradient(c, a)
grad2 = tape2.gradient(c[:, 0], a)
print(grad1)
print(grad2)

И это результат:

tf.Tensor(
[[1. 2.]
 [2. 3.]], shape=(2, 2), dtype=float32)
None

Как вы можете заметить, tf.GradientTape () является не работает с нарезанными выходами. Есть ли способ обойти это?

1 Ответ

3 голосов
/ 03 августа 2020

Да, все, что вы делаете с тензорами, должно происходить внутри контекста ленты. Вы можете исправить это относительно легко следующим образом:

import tensorflow as tf

a = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
b = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)

with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
    tape1.watch(a)
    tape2.watch(a)
    
    c = a * b
    c_sliced = c[:, 0]

grad1 = tape1.gradient(c, a)
grad2 = tape2.gradient(c_sliced, a)
print(grad1)
print(grad2)
...