Я пытаюсь построить пользовательский слой в TensorFlow, используя коэффициенты матриц в экспоненте, чтобы у слоя были базисные матрицы M1 и M2 и подходящие коэффициенты a и b, чтобы слой действовал на входной вектор с матрицей exp (a M1 + b M2).
Для этого градиента нет решения в замкнутой форме, и TensorFlow все равно не может принять градиент экспоненциальной матрицы, поэтому мне нужно реализовать собственный градиент относительно а и б в моем классе слоя. Вот код для слоя:
class Matrix(layers.Layer):
"""Class for the linear transformation layer in the network"""
def __init__(self, dim=1):
# This won't work for arbitrary dimensionality
dims = [1, 2, 3]
assert dim in dims, "Dimensionality {} is not 1, 2, or 3".format(dim)
super(Matrix, self).__init__()
A_init = tf.random_normal_initializer()
self.dim = dim
# define the basis matrices to generate the Lie transform
if self.dim == 1:
M1 = np.array([[0., -1.],
[1., 0.]]) # basis 1
M2 = np.array([[1., 0.],
[0.,-1.]]) # basis 2
self.basis_matrices=[M1, M2]
self.A = tf.Variable(initial_value=A_init(shape=(2,),
dtype='float32'),
trainable=True)
# dims 2 and 3 are still pending
def compute_matrix_exp(self, zed):
if self.dim == 1:
exp_arg = self.A[0]*self.basis_matrices[0] + self.A[1]*self.basis_matrices[1]
M = tf.linalg.expm(exp_arg)
return tf.matmul(M, zed)
def compute_matrix_exp_grad(self, zed):
if self.dim == 1:
exp_arg = self.A[0]*self.basis_matrices[0] + self.A[1]*self.basis_matrices[1]
M = tf.linalg.expm(exp_arg)
dM_dA0 = tf.matmul(self.basis_matrices[0], M) #approximate
dM_dA1 = tf.matmul(self.basis_matrices[1], M) #approximate
return [tf.matmul(dM_dA0, zed), tf.matmul(dM_dA1, zed)]
@tf.custom_gradient
def call(self, zed):
def grad(zed):
grad = self.compute_matrix_exp_grad(zed)
return self.compute_matrix_exp(zed), grad
Я довольно новичок в TensorFlow, поэтому я не уверен, как лучше реализовать пользовательский градиент в любом случае. Большое спасибо за вашу помощь.
Редактировать: Я добавил свою попытку вывести пользовательский градиент и обнаружил, что он не обучает переменные a и b, ни кажется ли, что он даже звонит call()
, поэтому я в замешательстве.