У меня проблемы с вычислением градиента с помощью градиентной ленты при использовании tf.function
с условными переходами.
Внутри области градиентной ленты я пытаюсь вычислить градиент z
wrt self.LMn
. Это отлично работает, когда я не аннотирую функцию @tf.function
. Ошибка возникает в подклассе функции вызова tf.keras.layers.Layer
:
def call(self, x, training=None, labels=None, cur_switch=None, basis_filters=None):
if basis_filters is None:
basis_filters = self.initial_basis_filters
gap = tf.reduce_mean(x, axis=[1, 2])
z = tf.einsum('bc,cd->bd', gap, self.LMn)
# ... code for out
return out, z
Фактическая ошибка определяется следующим образом:
....
C:\...\ops\cond_v2.py:387 <lambda>
lambda: _grad_fn(func_graph, grads), [], {},
C:\...\ops\cond_v2.py:363 _grad_fn
assert len(func_graph.outputs) == len(grads)
AssertionError:
Точнее,
func_graphs.outputs = [<tf.Tensor 'vgg16/block1a/block1a_conv/cond_2/Identity:0' shape=(128, 32, 32, None) dtype=float32>, <tf.Tensor 'vgg16/block1a/block1a_conv/cond_2/OptionalFromValue:0' shape=() dtype=variant>, <tf.Tensor 'vgg16/block1a/block1a_conv/cond_2/OptionalFromValue_1:0' shape=() dtype=variant>]
и
grads = (<tf.Tensor 'gradient_tape/vgg16/block1a/block1a_conv/strided_slice_3/StridedSliceGrad_1:0' shape=(128, 32, 32, None) dtype=float32>,)
Я могу предположить, что каждый из этих выходов соответствует случаям для некоторого набора выходов условных ветвей, которые затем передаются в функцию tf.einsum
. Я прочитал все крайние случаи и меры предосторожности в документации по градиентной ленте, так как это кажется проблемой. Просто обратите внимание, что я выполняю только условные вычисления с использованием гиперпараметров (переданных в tf.function
как переменные pythoni c, например basis_filters
). Есть также некоторые условные переходы, использующие (tenorflow ops) функции этих гиперпараметров, это разрешено? или мне нужно вычислить эти значения вне tf.function
и передать их также как переменные pythoni c?
Я знаю, что вопрос не совсем ясен, и при необходимости могу предоставить любую дополнительную информацию. Было бы очень полезно получить рекомендации о том, что искать в случае такого рода проблем.
Спасибо!