Как тензорный поток обрабатывает недифференцируемые узлы во время вычисления градиента? - PullRequest
0 голосов
/ 08 ноября 2018

Я понял концепцию автоматического дифференцирования, но не смог найти никакого объяснения, как tenorflow вычисляет градиент ошибки для недифференцируемых функций, как, например, tf.where в моей функции потерь или tf.cond в моем графике. Это работает просто отлично, но я хотел бы понять, как тензорный поток распространяет ошибку через такие узлы, поскольку нет формулы для вычисления градиента по ним.

1 Ответ

0 голосов
/ 08 ноября 2018

В случае tf.where у вас есть функция с тремя входами, условием C, значением true = 1003 * и значением false F и одним выходом Out. Градиент получает одно значение и должен возвращать три значения. В настоящее время градиент для условия не вычисляется (что вряд ли имело бы смысл), поэтому вам просто нужно сделать градиенты для T и F. Предполагая, что вход и выход являются векторами, представьте, что C[0] равно True. Тогда Out[0] происходит от T[0], и его градиент должен распространяться обратно. С другой стороны, F[0] был бы отброшен, поэтому его градиент должен быть равен нулю. Если Out[1] было False, то градиент для F[1] должен распространяться, но не для T[1]. Итак, вкратце, для T вы должны распространить заданный градиент, где C равен True, и сделать его нулевым, если он равен False, и наоборот для F. Если вы посмотрите на реализацию градиента tf.where (Select операция) , он сделает именно это:

@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  zeros = array_ops.zeros_like(x)
  return (None, array_ops.where(c, grad, zeros), array_ops.where(
      c, zeros, grad))

Обратите внимание, что сами входные значения не используются в вычислениях, которые будут выполняться градиентами операции, производящей эти входные данные. Для tf.cond, код немного сложнее , потому что одна и та же операция (Merge) используется в разных контекстах, а также tf.cond также использует Switch операции внутри. Однако идея та же самая. По существу, операции Switch используются для каждого входа, поэтому активированный вход (первый, если условие было True, и второй в противном случае) получает полученный градиент, а другой вход получает градиент «выключен» (например, None) и не распространяется дальше.

...