Попытка создать собственный оптимизатор TensorFlow, который ведет себя по-разному в зависимости от формы весов - PullRequest
0 голосов
/ 07 мая 2020

Я пытаюсь создать собственный оптимизатор TensorFlow ( tf.keras.optimizers.Optimizer ), который по-разному обрабатывает веса различных форм.

Например, рассмотрите простой сверточная нейронная сеть со следующей формой весов (и смещений):

(3, 3, 3, 16)
(16,)
(3, 3, 16, 16)
(16,)
(2704, 64)
(64,)
(64, 10)
(10,)

В начале метода _resource_apply_dense(self, grad, var) в моем настраиваемом оптимизаторе я хотел бы преобразовать var из разные формы все в 2-х мерном.

Ниже приводится упрощенный лог c желаемого поведения:

def custom_train_step(var):
    if (tf.rank(var) == 1):
        print("Case1")
        return tf.expand_dims(input=var, axis=0)
    elif tf.rank(var) == 2:
        print("Case2")
        return tf.transpose(a=var)
    elif tf.rank(var) == 4:
        print("Case3")
        var = tf.transpose(a=var, perm=(1,0,3,2))
        return flatten_to_2d(var)
    else:
        # omitted
        pass

Однако это не будет работать, когда ndim(var)<4, поскольку кажется, что когда TensorFlow строит свой граф вычислений , отслеживаются все 4 ветви, в том числе Case3. Другими словами, при использовании текущей реализации 1d и 2d var во время трассировки также будут переданы в tf.transpose(a=var, perm=(1,0,3,2)), что приведет к ошибкам:

ValueError: Dimension must be 1 but is 4 for '{{node transpose}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](transpose/ReadVariableOp, Const)' with input shapes: [16], [4].

( Ошибка возникла, когда var - это тензор смещения формы (16,))

Ошибка может быть воспроизведена путем декорирования вышеуказанного метода с помощью tf.function. Вот ссылка Colab на этот образец игрушки.

Я пробовал напрямую писать условные операторы, используя tf.cond, tf.case и tf.switch_case, но ошибка осталась. Я понимаю, что это связано с тем, что метод custom_train_step(var) является polymorphi c, что сделало его необходимым для повторной трассировки, но я не могу придумать способа избежать такого поведения, улучшив код. (Обратите внимание еще раз, что я, вероятно, не могу написать 4 ветки в отдельных методах и украсить каждую из них tf.function, потому что это должно вызываться внутри tf.keras training l oop. Пожалуйста, поправьте меня, если я ошибаюсь .)

Я хотел бы знать, есть ли обходной путь для достижения того, что я описал выше, или он еще не поддерживается Tensorflow?

Любая помощь и предложения будут оценены! При необходимости можно предоставить более подробную информацию. Спасибо!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...