Объединение весов в пользовательском слое Keras с использованием add_weight не удается при вычислении градиентов - PullRequest
1 голос
/ 17 июня 2020

При написании пользовательского слоя мне нужно объединить несколько весовых матриц. Если я сделаю это в функции build, я получу ошибку ValueError: No gradients provided for any variable:..., однако, если я составлю список существ в build и объединю их в call, это сработает. Вот минимальный код для создания ошибки:

class MultiInputLinear(Layer):
    def __init__(self, output_dim=32, n_inputs=2):
        super(Linear, self).__init__()
        self.output_dim = output_dim
        self.n_inputs = n_inputs


    def build(self, input_shapes):
        self.input_dim = input_shapes[0][1]

        self.W = tf.concat(
            [
                self.add_weight(
                    name=f'W_{i}',
                    shape=(self.input_dim, self.output_dim),
                    initializer='random_normal',
                    trainable=True
                ) for i in range(self.n_inputs)
            ], axis=0
        )

    def call(self, inputs):  
        supports = tf.concat(inputs, axis=-1)        
        return tf.matmul(supports, self.W)

N = 100
A = [np.random.normal(size=(N, N)) for _ in range(2)]
y = np.random.binomial(1, .1, size=(N, 32))

A_in = [Input(batch_size=N, shape=(N, )) for _ in range(2)]
Y = MultiInputLinear(y.shape[1], 2)(A_in)

model = Model(inputs=A_in, outputs=Y)
model.compile(loss='categorical_crossentropy', optimizer=Adam())

model.fit(A, y, batch_size=N)

Однако, если в build я сохраняю список, подобный этому:

self.W_list = [
                self.add_weight(
                    name=f'W_{i}',
                    shape=(self.input_dim, self.output_dim),
                    initializer='random_normal',
                    trainable=True
                ) for i in range(self.n_inputs)
            ]

, а затем внутри call я объединяю с ними, как показано ниже, не было бы никаких проблем:

    def call(self, inputs):  
        supports = tf.concat(inputs, axis=-1)
        W = tf.concat(self.W_list, axis=0)

        return tf.matmul(supports, W)

Мне было интересно, в чем причина этого.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...