Разрешение имен в пользовательских градиентах Tensorflow - PullRequest
0 голосов
/ 03 мая 2019

Я создаю несколько пользовательских операций в Tensorflow с соответствующими градиентами.Все отлично работает в отдельности, но я сталкиваюсь с проблемой, когда 2 из моих пакетов определяют две разные операции (разные входы) с одним и тем же именем.

Чтобы упростить мой вопрос, представьте, что я определен как matmul работа в двух пакетах.Его можно легко использовать, как в следующем коде:

import tensorflow as tf
my_ops_a = tf.load_op_library('libpackage_a.so')
my_ops_b = tf.load_op_library('libpackage_b.so')

x, y = tf.random.uniform(10,10), tf.random.uniform(10,10)
my_ops_a.matmul(x, y)
my_ops_b.matmul(x, y)

И его градиент может быть передан в Tensorflow как:

from tensorflow.python.framework import ops as tf_ops

@tf_ops.RegisterGradient("Matmul")
def _mat_mul_grad(op, grad):
    return my_ops_a.mat_mul_grad(grad, op.inputs[0], op.inputs[1])

@tf_ops.RegisterGradient("Matmul")
def _mat_mul_grad(op, grad):
    return my_ops_b.mat_mul_grad(grad, op.inputs[0], op.inputs[1])

Однако @tf_ops.RegisterGradient не имеет никакого способаопределить, на что matmul я ссылаюсь.

На самом деле, когда я пытаюсь запустить этот код, я получил следующую ошибку:

KeyError: "Registering two gradient with name 'Matmul'! (Previous registration was in <module> ...)

Как я могу сообщить Tensorflow, что я имею в виду операцию определенного пакета?

Заранее спасибо.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...