Условные переходы с использованием tf.function - PullRequest
2 голосов
/ 29 мая 2020

У меня проблемы с вычислением градиента с помощью градиентной ленты при использовании tf.function с условными переходами.

Внутри области градиентной ленты я пытаюсь вычислить градиент z wrt self.LMn. Это отлично работает, когда я не аннотирую функцию @tf.function. Ошибка возникает в подклассе функции вызова tf.keras.layers.Layer:

def call(self, x, training=None, labels=None, cur_switch=None, basis_filters=None):
    if basis_filters is None:
        basis_filters = self.initial_basis_filters

    gap = tf.reduce_mean(x, axis=[1, 2])
    z = tf.einsum('bc,cd->bd', gap, self.LMn)

    # ... code for out

    return out, z

Фактическая ошибка определяется следующим образом:

....
C:\...\ops\cond_v2.py:387 <lambda>
    lambda: _grad_fn(func_graph, grads), [], {},
C:\...\ops\cond_v2.py:363 _grad_fn
    assert len(func_graph.outputs) == len(grads)

AssertionError: 

Точнее,

func_graphs.outputs = [<tf.Tensor 'vgg16/block1a/block1a_conv/cond_2/Identity:0' shape=(128, 32, 32, None) dtype=float32>, <tf.Tensor 'vgg16/block1a/block1a_conv/cond_2/OptionalFromValue:0' shape=() dtype=variant>, <tf.Tensor 'vgg16/block1a/block1a_conv/cond_2/OptionalFromValue_1:0' shape=() dtype=variant>]

и

grads = (<tf.Tensor 'gradient_tape/vgg16/block1a/block1a_conv/strided_slice_3/StridedSliceGrad_1:0' shape=(128, 32, 32, None) dtype=float32>,)

Я могу предположить, что каждый из этих выходов соответствует случаям для некоторого набора выходов условных ветвей, которые затем передаются в функцию tf.einsum. Я прочитал все крайние случаи и меры предосторожности в документации по градиентной ленте, так как это кажется проблемой. Просто обратите внимание, что я выполняю только условные вычисления с использованием гиперпараметров (переданных в tf.function как переменные pythoni c, например basis_filters). Есть также некоторые условные переходы, использующие (tenorflow ops) функции этих гиперпараметров, это разрешено? или мне нужно вычислить эти значения вне tf.function и передать их также как переменные pythoni c?

Я знаю, что вопрос не совсем ясен, и при необходимости могу предоставить любую дополнительную информацию. Было бы очень полезно получить рекомендации о том, что искать в случае такого рода проблем.

Спасибо!

...