Tensorflow 2: пользовательский слой, который инкапсулирует слой - PullRequest
0 голосов
/ 16 марта 2020

Я хочу реализовать пользовательский слой с тензорным потоком 2, подобный тому, который описан в этой статье: https://arxiv.org/abs/1511.00363

Специфика этого пользовательского слоя заключается в том, что веса слоя бинаризованы. (значение -1 или 1) перед прямым проходом.

Для плотного слоя у меня есть следующий код:

class SimpleClip(tf.keras.constraints.Constraint):

    def __init__(self, H=1.0):

        self.min_value = -H
        self.max_value = H

    def __call__(self, p):

        return tf.clip_by_value(p, self.min_value, self.max_value)

    def get_config(self):

        return {"min_value": self.min_value,
                "max_value": self.max_value}

@tf.custom_gradient
def sign_grad(x):

  def grad(dy):
    return dy
  return tf.sign(x), grad

#Binarize layer
class BinaryLayer(tf.keras.layers.Layer):

    def __init__(self,unit,activation):

        super(BinaryLayer, self).__init__(name='my_layer')
        self.unit=unit
        self.activation=activation

    def build(self, input_shape):

        self.kernel_initializer = initializers.RandomUniform(-1, 1)
        self.kernel_constraint=SimpleClip()
        self.kernel = self.add_weight(shape=(input_shape[1],self.unit),trainable=True,initializer=self.kernel_initializer,constraint=self.kernel_constraint,name='kernel')
        self.bias = self.add_weight(shape=(self.unit,),name='bias')


    def call(self,inputs):

        binarize_weight=sign_grad(self.kernel)
        x=tf.matmul(inputs,binarize_weight) + self.bias
        output=self.activation(x)
        return output

Я мог бы сделать тот же процесс для слоя conv2D.

Но я хочу создать пользовательский слой, который принимает в качестве аргумента слой keras и в своей функции вызова бинаризует вес слоя. Аргументом может быть плотный слой или слой Conv2D.

#Binarize layer
class BinaryLayer(tf.keras.layers.Layer):

    def __init__(self,layer):
        super(BinaryLayer, self).__init__(name='my_layer')
        self.layer=layer
    ....

Как я могу это сделать?

...