Как использовать градиент_override_map в Tensorflow 2.0? - PullRequest
2 голосов
/ 19 апреля 2019

Я пытаюсь использовать gradient_override_map с Tensorflow 2.0.В документации есть пример , который я также буду использовать здесь в качестве примера.

В 2.0 GradientTape можно использовать для вычисления градиентов следующим образом:

import tensorflow as tf
print(tf.version.VERSION)  # 2.0.0-alpha0

x = tf.Variable(5.0)
with tf.GradientTape() as tape:
    s_1 = tf.square(x)
print(tape.gradient(s_1, x))

Существует также декоратор tf.custom_gradient, который можно использовать для определения градиента для новой функции (опять же, используя пример из документации ):

import tensorflow as tf
print(tf.version.VERSION)  # 2.0.0-alpha

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)

    def grad(dy):
        return dy * (1 - 1 / (1 + e))

    return tf.math.log(1 + e), grad

x = tf.Variable(100.)

with tf.GradientTape() as tape:
    y = log1pexp(x)

print(tape.gradient(y, x))

Однако я хотел бы заменить градиент для стандартных функций, таких как tf.square.Я попытался использовать следующий код:

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  return tf.constant(0)

with tf.Graph().as_default() as g:
    x = tf.Variable(5.0)
    with g.gradient_override_map({"Square": "CustomSquare"}):
        with tf.GradientTape() as tape:
            s_2 = tf.square(x, name="Square")

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())            
        print(sess.run(tape.gradient(s_2, x)))

Однако есть две проблемы: замена градиента не работает (она оценивается как 10.0 вместо 0.0), и мне нужноприбегнуть к session.run() для выполнения графика.Есть ли способ добиться этого в «родном» TensorFlow 2.0?

В TensorFlow 1.12.0 следующее дает желаемый результат:

import tensorflow as tf
print(tf.__version__)  # 1.12.0

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  return tf.constant(0)

x = tf.Variable(5.0)

g = tf.get_default_graph()
with g.gradient_override_map({"Square": "CustomSquare"}):
    s_2 = tf.square(x, name="Square")
grad = tf.gradients(s_2, x)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(grad))

1 Ответ

3 голосов
/ 22 апреля 2019

В TensorFlow 2.0 нет встроенного механизма для переопределения всех градиентов для встроенного оператора в области видимости.Однако, если вы можете изменить сайт вызова для каждого вызова встроенного оператора, вы можете использовать декоратор tf.custom_gradient следующим образом:

@tf.custom_gradient
def custom_square(x):
  def grad(dy):
    return tf.constant(0.0)
  return tf.square(x), grad

with tf.Graph().as_default() as g:
  x = tf.Variable(5.0)
  with tf.GradientTape() as tape:
    s_2 = custom_square(x)

  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())            
    print(sess.run(tape.gradient(s_2, x)))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...