Я создаю несколько пользовательских операций в Tensorflow с соответствующими градиентами.Все отлично работает в отдельности, но я сталкиваюсь с проблемой, когда 2 из моих пакетов определяют две разные операции (разные входы) с одним и тем же именем.
Чтобы упростить мой вопрос, представьте, что я определен как matmul
работа в двух пакетах.Его можно легко использовать, как в следующем коде:
import tensorflow as tf
my_ops_a = tf.load_op_library('libpackage_a.so')
my_ops_b = tf.load_op_library('libpackage_b.so')
x, y = tf.random.uniform(10,10), tf.random.uniform(10,10)
my_ops_a.matmul(x, y)
my_ops_b.matmul(x, y)
И его градиент может быть передан в Tensorflow как:
from tensorflow.python.framework import ops as tf_ops
@tf_ops.RegisterGradient("Matmul")
def _mat_mul_grad(op, grad):
return my_ops_a.mat_mul_grad(grad, op.inputs[0], op.inputs[1])
@tf_ops.RegisterGradient("Matmul")
def _mat_mul_grad(op, grad):
return my_ops_b.mat_mul_grad(grad, op.inputs[0], op.inputs[1])
Однако @tf_ops.RegisterGradient
не имеет никакого способаопределить, на что matmul
я ссылаюсь.
На самом деле, когда я пытаюсь запустить этот код, я получил следующую ошибку:
KeyError: "Registering two gradient with name 'Matmul'! (Previous registration was in <module> ...)
Как я могу сообщить Tensorflow, что я имею в виду операцию определенного пакета?
Заранее спасибо.