Определение пользовательского градиента как метода класса в Tensorflow - PullRequest
0 голосов
/ 22 февраля 2019

Мне нужно определить метод в качестве пользовательского градиента следующим образом:

class CustGradClass:

    def __init__(self):
        pass

    @tf.custom_gradient
    def f(self,x):
      fx = x
      def grad(dy):
        return dy * 1
      return fx, grad

Я получаю следующую ошибку:

ValueError: Попытка преобразовать значение (<<strong> main .CustGradClass объект в 0x12ed91710>) с неподдерживаемым типом () для Tensor.

Причина в том, что пользовательский градиент принимает функцию f (* x) где x - последовательность тензоров.И первым передаваемым аргументом является сам объект, т. Е. self .

Из документации :

f: функция f (* x), которая возвращает кортеж (y, grad_fn), где: x - это последовательность входов Tensor для функции .y является Tensor или последовательностью выходных данных Tensor применения операций TensorFlow в f к x.grad_fn - это функция с сигнатурой g (* grad_ys)

Как мне заставить это работать?Нужно ли мне наследовать некоторый класс тензорного потока Python?

Я использую tf версии 1.12.0 и режим ожидания.

Ответы [ 2 ]

0 голосов
/ 22 февраля 2019

Это один из возможных простых обходных путей:

import tensorflow as tf

class CustGradClass:

    def __init__(self):
        self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))

    @staticmethod
    def _f(self, x):
        fx = x * 1
        def grad(dy):
            return dy * 1
        return fx, grad

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.constant(1.0)
    c = CustGradClass()
    y = c.f(x)
    print(tf.gradients(y, x))
    # [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]

РЕДАКТИРОВАТЬ:

Если вы хотите делать это много раз в разных классах или просто хотите более многоразовое решение, выможно использовать такой декоратор, например, как:

import functools
import tensorflow as tf

def tf_custom_gradient_method(f):
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        if not hasattr(self, '_tf_custom_gradient_wrappers'):
            self._tf_custom_gradient_wrappers = {}
        if f not in self._tf_custom_gradient_wrappers:
            self._tf_custom_gradient_wrappers[f] = tf.custom_gradient(lambda *a, **kw: f(self, *a, **kw))
        return self._tf_custom_gradient_wrappers[f](*args, **kwargs)
    return wrapped

Тогда вы можете просто сделать:

class CustGradClass:

    def __init__(self):
        pass

    @tf_custom_gradient_method
    def f(self, x):
        fx = x * 1
        def grad(dy):
            return dy * 1
        return fx, grad

    @tf_custom_gradient_method
    def f2(self, x):
        fx = x * 2
        def grad(dy):
            return dy * 2
        return fx, grad
0 голосов
/ 22 февраля 2019

В вашем примере вы не используете никаких переменных-членов, поэтому вы можете просто сделать метод статическим методом.Если вы используете переменные-члены, тогда вызовите статический метод из функции-члена и передайте переменные-члены в качестве параметров.

class CustGradClass:

  def __init__(self):
    self.some_var = ...

  @staticmethod
  @tf.custom_gradient
  def _f(x):
    fx = x
    def grad(dy):
      return dy * 1

    return fx, grad

  def f(self):
    return CustGradClass._f(self.some_var)
...