Прежде всего умножение матриц не является коммутативным, поэтому (AXB) не равно (BXA), поэтому tf.matmul (A, B) не равно tf.matmul (B, A).
Учитывая ввод с размером (d (размер пакета), n (длина вектора)), вы хотите применить свою функцию Ax + b для каждого вектора в пакете. Вот код и вывод.
Вход
[[ 0. 1. 2. 3. 4. 5. 6.]
[ 7. 8. 9. 10. 11. 12. 13.]]
Код
class custom_layer(keras.layers.Layer):
def __init__(self, *args, **kwargs):
super(custom_layer, self).__init__(*args, **kwargs)
# Defines A and b to be trainable
def build(self, input_shape):
self.weight = self.add_weight(shape=(1,input_shape[1]),
initializer='ones',
trainable=True)
self.bias = self.add_weight(shape=(input_shape[0]),
initializer='zeros',
trainable=True)
super(custom_layer, self).build(input_shape)
def call(self, x):
# Apply Linear Regression
x_out = tf.matmul(self.weight,x,transpose_b=True) + self.bias
# Concatenate map output with input to form graph!
x_out = tf.concat([x,tf.transpose(x_out)],-1)
return x_out
Выход
[[ 0. 1. 2. 3. 4. 5. 6. 21.]
[ 7. 8. 9. 10. 11. 12. 13. 70.]]
Если вы хотите применить Ax + б элемент мудрый. Простое слово, каждый элемент имеет свой вес и уклон.
Вход
[[ 0. 1. 2. 3. 4. 5. 6.]
[ 7. 8. 9. 10. 11. 12. 13.]]
Код
class custom_layer(keras.layers.Layer):
def __init__(self, *args, **kwargs):
super(custom_layer, self).__init__(*args, **kwargs)
# Defines A and b to be trainable
def build(self, input_shape):
self.weight = self.add_weight(shape=(1,input_shape[1]),
initializer='ones',
trainable=True)
self.bias = self.add_weight(shape=(1,input_shape[1]),
initializer='zeros',
trainable=True)
super(custom_layer, self).build(input_shape)
def call(self, x):
# Apply Linear Regression
x_out = tf.multiply(self.weight,x) + self.bias
# Concatenate map output with input to form graph!
x_out = tf.concat([x,x_out],-1)
return x_out
Выход
[[ 0. 1. 2. 3. 4. 5. 6. 0. 1. 2. 3. 4. 5. 6.]
[ 7. 8. 9. 10. 11. 12. 13. 7. 8. 9. 10. 11. 12. 13.]]
Редактировать
Код для тестирования слоя
test = custom_layer()
in1 = Input(shape=(7,))
out = test(in1)
# test
M = Model(inputs=[in1],outputs=[out])
M.compile(keras.optimizers.Adam(),loss='mse')
print(np.arange(14,dtype=np.float32).reshape(2,7))