Градиент TensorFlow с tf.where возвращает NaN, когда он не должен - PullRequest
0 голосов
/ 05 мая 2018

Ниже приведен воспроизводимый код. Если вы запустите его, вы увидите, что в первом сеансе sess результатом будет Nan, тогда как во втором случае правильное значение градиента равно 0,5. Но согласно указанным условиям и условиям они должны возвращать одно и то же значение. Я также просто не понимаю, почему градиент функции tf.where равен nan при 1 или -1, что мне кажется вполне подходящим входным значением.

tf.reset_default_graph()
x = tf.get_variable('x', shape=[1])
condition = tf.less(x, 0.0)
output = tf.where(condition, -tf.log(-x + 1), tf.log(x + 1))
deriv = tf.gradients(output, x)
with tf.Session() as sess:
    print(sess.run(deriv, {x:np.array([-1])}))

logg = -tf.log(-x+1)
derivv = tf.gradients(logg, x)
with tf.Session() as sess:
    print(sess.run(derivv, {x:np.array([-1])}))

Спасибо за комментарии!

1 Ответ

0 голосов
/ 05 мая 2018

Как объясняется в проблеме github , предоставленной @mikkola, проблема связана с внутренней реализацией tf.where. В основном, обе альтернативы (и их градиент) вычисляются, и только умножение условной выборки выбирает только правильную часть. Увы, если градиент равен inf или nan для детали, которая не выбрана, даже при умножении на 0 вы получите nan, который в конечном итоге распространяется на результат.

Поскольку проблема была подана в мае 2016 года (это тензор потока v0.7!) И с тех пор не исправлена, можно смело предположить, что это произойдет не скоро, и начать искать обходной путь.

Самый простой способ исправить это - изменить ваши операторы так, чтобы они всегда были действительными и дифференцируемыми даже для значений, которые не предназначены для выбора.

Общий метод заключается в том, чтобы обрезать входное значение внутри его допустимого домена. Так, в вашем случае, например, вы можете использовать

cond = tf.less(x, 0.0)
output = tf.where(cond,
  -tf.log(-tf.where(cond, x, 0) + 1),
  tf.log(tf.where(cond, 0, x) + 1))

В вашем конкретном случае, однако, было бы проще использовать

output = tf.sign(x) * tf.log(tf.abs(x) + 1)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...