Использование автографа при расчете градиента по tf.case - PullRequest
0 голосов
/ 30 апреля 2020

Я пытаюсь вычислить градиент по tf.case, используя автограф.

Например, допустим, у меня есть функция case, где она принимает пакет ввода и вычисляет вывод на основе знака input:

def case_fn(x):                                                                                                                                                                                                                                                                                                          
    N = tf.shape(x)[0]                                                                                                                                                                                                                                                                                                   
    positive_idx = tf.cast(tf.squeeze(tf.where(tf.squeeze(tf.math.greater(x, 0.)))),tf.int32)                                                                                                                                                                                                                            
    negative_idx = tf.cast(tf.squeeze(tf.where(tf.squeeze(tf.math.less_equal(x, 0.)))),tf.int32)                                                                                                                                                                                                                         
    def all_positive_case():                                                                                                                                                                                                                                                                                             
        y_positive = x*2.                                                                                                                                                                                                                                                                                                

        return y_positive                                                                                                                                                                                                                                                                                                

    def all_negative_case():                                                                                                                                                                                                                                                                                             
        y_negative = x-2.                                                                                                                                                                                                                                                                                                

        return y_negative                                                                                                                                                                                                                                                                                                

    def some_positive_some_negative_case():                                                                                                                                                                                                                                                                              
        x_positive = tf.gather(x, positive_idx)                                                                                                                                                                                                                                                                          
        x_negative = tf.gather(x, negative_idx)                                                                                                                                                                                                                                                                          

        y_positive = x_positive*2.                                                                                                                                                                                                                                                                                       
        y_negative = x_negative-2.                                                                                                                                                                                                                                                                                       

        y_positive = tf.scatter_nd(tf.expand_dims(positive_idx,1),y_positive,tf.stack([N,1]))                                                                                                                                                                                                                            
        y_negative = tf.scatter_nd(tf.expand_dims(negative_idx,1),y_negative,tf.stack([N,1]))                                                                                                                                                                                                                            

        return y_positive + y_negative                                                                                                                                                                                                                                                                                   

    all_positive = tf.math.equal(tf.shape(negative_idx)[0], 0)                                                                                                                                                                                                                                                           
    all_negative = tf.math.equal(tf.shape(positive_idx)[0], 0)                                                                                                                                                                                                                                                           
    return tf.case([(all_positive, all_positive_case), (all_negative, all_negative_case)], default=some_positive_some_negative_case)

Затем я вычисляю градиент с помощью следующего кода:

trainable_variable = tf.Variable([[1.], [-1.], [2.], [-2.]])                                                                                                                                                                                                                                                             
@tf.function                                                                                                                                                                                                                                                                                                             
def compute_grad():                                                                                                                                                                                                                                                                                                      
    with tf.GradientTape() as tape:                                                                                                                                                                                                                                                                                      
        y = case_fn(trainable_variable)                                                                                                                                                                                                                                                                                  
    grad = tape.gradient(y, trainable_variable)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
    return grad                                                                                                                                                                                                                                                                                                          

print(compute_grad())   

Если я не использую @tf.function декоратор, он возвращает правильное значение, равное IndexedSlices(indices=tf.Tensor([0, 2, 1, 3], shape=(4,), dtype=int32), values=tf.Tensor([[2.],[2.],[1.],[1.]], shape=(4, 1), dtype=float32), dense_shape=tf.Tensor([4 1], shape=(2,), dtype=int32)). Однако, если я использую @tf.function декоратор, он возвращает ошибку значения, говорящую

Traceback (most recent call last):
  File "examples/case_gradient.py", line 102, in <module>
    print(compute_grad())
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 568, in __call__
    result = self._call(*args, **kwds)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 615, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 497, in _initialize
    *args, **kwds))
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 439, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:

    examples/case_gradient.py:99 compute_grad  *
        grad = tape.gradient(y, trainable_variable)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/backprop.py:1029 gradient
        unconnected_gradients=unconnected_gradients)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/imperative_grad.py:77 imperative_grad
        compat.as_str(unconnected_gradients.value))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/eager/backprop.py:141 _gradient_function
        return grad_fn(mock_op, *out_grads)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:121 _IfGrad
        false_graph, grads, util.unique_grad_fn_name(false_graph.name))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:381 _create_grad_func
        func_graph=_CondGradFuncGraph(name, func_graph))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:978 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:380 <lambda>
        lambda: _grad_fn(func_graph, grads), [], {},
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:371 _grad_fn
        src_graph=func_graph)
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/gradients_util.py:669 _GradientsHelper
        lambda: grad_fn(op, *out_grads))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/gradients_util.py:336 _MaybeCompile
        return grad_fn()  # Exit early
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/gradients_util.py:669 <lambda>
        lambda: grad_fn(op, *out_grads))
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:183 _IfGrad
        building_gradient=True,
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:219 _build_cond
        _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
    /home/junhyeok/.venv/ccmbrl/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:652 _make_indexed_slices_indices_types_match
        (current_index, len(branch_graphs[0].outputs)))

    ValueError: Insufficient elements in branch_graphs[0].outputs.
    Expected: 6
    Actual: 3

Чего мне здесь не хватает?

1 Ответ

1 голос
/ 30 апреля 2020

Я проверил последнюю версию 2.2.0-rc3 и не вижу этой проблемы. Это может быть решено в новой версии.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...