Вы можете определить пользовательский слой Keras для этого, где вы можете передать справочный Dense
слой.
Плотный слой:
class CustomDense(Layer):
def __init__(self, reference_layer):
super(CustomDense, self).__init__()
self.ref_layer = reference_layer
def call(self, inputs):
weights = self.ref_layer.get_weights()[0]
bias = self.ref_layer.get_weights()[1]
weights = tf.transpose(weights)
x = tf.linalg.matmul(inputs, weights) + bias
return x
Теперь вы добавляете этот слой в вашу модель используя Functional-API .
inp = Input(shape=(5))
dense = Dense(5)
transposed_dense = CustomDense(dense)
#model
x = dense(inp)
x = transposed_dense(x)
model = Model(inputs=inp, outputs=x)
model.summary()
'''
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 5)] 0
_________________________________________________________________
dense_1 (Dense) (None, 5) 30
_________________________________________________________________
custom_dense_1 (CustomDense) (None, 5) 30
=================================================================
Total params: 30
Trainable params: 30
Non-trainable params: 0
_________________________________________________________________
'''
Как видите, dense
и custom_dense
имеют 30 общих параметров. Здесь custom_dense
просто выполняет плотную операцию, используя транспонированные веса слоя dense
, и у него нет собственного параметра.
РЕДАКТИРОВАТЬ 1: Ответ на вопрос в комментарии (Как субклассифицированный слой получает # params?):
Класс слоя отслеживает все объекты, передаваемые в его метод __init__
.
transposed_dense._layers
# [<tensorflow.python.keras.layers.core.Dense at 0x7fc3e0874f28>]
Приведенный выше параметр даст зависимые слои, которые отслеживаются. Все веса дочерних атрибутов можно рассматривать как:
transposed_dense._gather_children_attribute("weights")
#[<tf.Variable 'dense_9/kernel:0' shape=(10, 5) dtype=float32>,
# <tf.Variable 'dense_9/bias:0' shape=(5,) dtype=float32>]
Следовательно, когда мы вызываем model.summary()
, он внутренне вызывает count_params()
для каждого Layer
, что означает все trainable_variable , включая self и атрибуты детей.