Я пытаюсь создать собственный оптимизатор TensorFlow ( tf.keras.optimizers.Optimizer ), который по-разному обрабатывает веса различных форм.
Например, рассмотрите простой сверточная нейронная сеть со следующей формой весов (и смещений):
(3, 3, 3, 16)
(16,)
(3, 3, 16, 16)
(16,)
(2704, 64)
(64,)
(64, 10)
(10,)
В начале метода _resource_apply_dense(self, grad, var)
в моем настраиваемом оптимизаторе я хотел бы преобразовать var
из разные формы все в 2-х мерном.
Ниже приводится упрощенный лог c желаемого поведения:
def custom_train_step(var):
if (tf.rank(var) == 1):
print("Case1")
return tf.expand_dims(input=var, axis=0)
elif tf.rank(var) == 2:
print("Case2")
return tf.transpose(a=var)
elif tf.rank(var) == 4:
print("Case3")
var = tf.transpose(a=var, perm=(1,0,3,2))
return flatten_to_2d(var)
else:
# omitted
pass
Однако это не будет работать, когда ndim(var)<4
, поскольку кажется, что когда TensorFlow строит свой граф вычислений , отслеживаются все 4 ветви, в том числе Case3
. Другими словами, при использовании текущей реализации 1d и 2d var
во время трассировки также будут переданы в tf.transpose(a=var, perm=(1,0,3,2))
, что приведет к ошибкам:
ValueError: Dimension must be 1 but is 4 for '{{node transpose}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](transpose/ReadVariableOp, Const)' with input shapes: [16], [4].
( Ошибка возникла, когда var
- это тензор смещения формы (16,)
)
Ошибка может быть воспроизведена путем декорирования вышеуказанного метода с помощью tf.function
. Вот ссылка Colab на этот образец игрушки.
Я пробовал напрямую писать условные операторы, используя tf.cond
, tf.case
и tf.switch_case
, но ошибка осталась. Я понимаю, что это связано с тем, что метод custom_train_step(var)
является polymorphi c, что сделало его необходимым для повторной трассировки, но я не могу придумать способа избежать такого поведения, улучшив код. (Обратите внимание еще раз, что я, вероятно, не могу написать 4 ветки в отдельных методах и украсить каждую из них tf.function
, потому что это должно вызываться внутри tf.keras training l oop. Пожалуйста, поправьте меня, если я ошибаюсь .)
Я хотел бы знать, есть ли обходной путь для достижения того, что я описал выше, или он еще не поддерживается Tensorflow?
Любая помощь и предложения будут оценены! При необходимости можно предоставить более подробную информацию. Спасибо!