keras определяют обучаемую переменную для add или matmul - PullRequest
0 голосов
/ 11 июля 2019

У меня есть некоторые проблемы при использовании tf.keras для построения модели.Теперь я хочу определить тензор веса поезда с формой (64, 128), которая похожа на tf.get_variable.Однако я не могу достичь этого.

В прошлом я пробовал много методов. Но я хочу легко найти метод.

inputs = tf.keras.Input((128,))
weights = tf.Variable(tf.random.normal((64, 128)))
output = tf.keras.layers.Lambda(lambda x: tf.matmul(x, tf.transpose(weights)))(inputs)
model = tf.keras.Model(inputs, output)
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_10 (InputLayer)        (None, 128)               0         
_________________________________________________________________
lambda_2 (Lambda)            (None, 64)                0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0

Определенные веса не поддаются обучению.
Кроме того, я знаю, что Dense может получить обученные веса матриц и смещения.Но если я хочу добавить смещение, я не могу использовать Dense.
Тем не менее, я должен использовать add_weights на уровне обычая, например:

class Bias(keras.layers.Layer):
    def build(self, input_shape):
        self.bias = self.add_weight(shape=(64, 128), initializer='zeros', dtype=tf.float32, name='x')
        self.built = True

    def call(self, inputs):
        return inputs + self.bias

inputs = Input(shape=(64, 128))
outputs = Bias()(inputs)
model = Model(inputs=inputs, outputs=outputs)
model.summary()
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        (None, 64, 128)           0         
_________________________________________________________________
bias_5 (Bias)                (None, 64, 128)           8192      
=================================================================
Total params: 8,192
Trainable params: 8,192
Non-trainable params: 0

Есть ли более простой способ определитьобучаемая переменная?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...