В случае tf.where
у вас есть функция с тремя входами, условием C
, значением true = 1003 * и значением false F
и одним выходом Out
. Градиент получает одно значение и должен возвращать три значения. В настоящее время градиент для условия не вычисляется (что вряд ли имело бы смысл), поэтому вам просто нужно сделать градиенты для T
и F
. Предполагая, что вход и выход являются векторами, представьте, что C[0]
равно True
. Тогда Out[0]
происходит от T[0]
, и его градиент должен распространяться обратно. С другой стороны, F[0]
был бы отброшен, поэтому его градиент должен быть равен нулю. Если Out[1]
было False
, то градиент для F[1]
должен распространяться, но не для T[1]
. Итак, вкратце, для T
вы должны распространить заданный градиент, где C
равен True
, и сделать его нулевым, если он равен False
, и наоборот для F
. Если вы посмотрите на реализацию градиента tf.where
(Select
операция) , он сделает именно это:
@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
c = op.inputs[0]
x = op.inputs[1]
zeros = array_ops.zeros_like(x)
return (None, array_ops.where(c, grad, zeros), array_ops.where(
c, zeros, grad))
Обратите внимание, что сами входные значения не используются в вычислениях, которые будут выполняться градиентами операции, производящей эти входные данные. Для tf.cond
, код немного сложнее , потому что одна и та же операция (Merge
) используется в разных контекстах, а также tf.cond
также использует Switch
операции внутри. Однако идея та же самая. По существу, операции Switch
используются для каждого входа, поэтому активированный вход (первый, если условие было True
, и второй в противном случае) получает полученный градиент, а другой вход получает градиент «выключен» (например, None
) и не распространяется дальше.