Как передать веса предыдущих слоев в качестве входных данных для функции вызова слоя custum в функциональной модели Keras? - PullRequest
1 голос
/ 24 января 2020

Для метода вызова моего пользовательского слоя мне нужны веса некоторых прецедентных слоев, но мне не нужно изменять их только для доступа к их значению. У меня есть значение, предложенное в Как получить вес слоя в Керасе? , но это возвращает вес в виде массива numpy. Поэтому я произвел их в Tensor (используя tf.convert_to_tensor из бэкэнда Keras), но в момент создания модели у меня появилась эта ошибка: «Объект NoneType не имеет атрибута _inbound_nodes». Как я могу решить эту проблему? Спасибо тебе.

Ответы [ 2 ]

1 голос
/ 24 января 2020

Вы можете пропустить этот прецедентный слой при инициализации своего пользовательского класса слоя.

Пользовательский слой:

class CustomLayer(Layer):
    def __init__(self, reference_layer):
      super(CustomLayer, self).__init__()
      self.ref_layer = reference_layer # precedent layer

    def call(self, inputs):
        weights = self.ref_layer.get_weights()
        ''' do something with these weights '''
        return something

Теперь вы добавляете этот слой в модель, используя Functional-API .

inp = Input(shape=(5))
dense = Dense(5)
custom_layer= CustomLayer(dense) # pass layer here

#model
x = dense(inp)
x = custom_layer(x)
model = Model(inputs=inp, outputs=x)

Здесь custom_layer может получить доступ к весам слоя dense.

1 голос
/ 24 января 2020

TensorFlow предоставляет коллекции графов, которые группируют переменные. Чтобы получить доступ к обученным переменным, вы должны вызвать tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) или его сокращение tf.trainable_variables() или получить все переменные (включая некоторые для статистики), используя tf.get_collection(tf.GraphKeys.VARIABLES) или его сокращение tf.all_variables()

tvars = tf.trainable_variables()
tvars_vals = sess.run(tvars)

for var, val in zip(tvars, tvars_vals):
    print(var.name, val)  # Prints the name of the variable alongside its value.
...