Как реализовать CRelu в Керасе? - PullRequest
0 голосов
/ 27 февраля 2019

Я пытаюсь реализовать слой CRelu в Keras.

Один вариант, который, кажется, работает, - это использовать лямбда-слой:

def _crelu(x):
    x = tf.nn.crelu(x, axis=-1)
    return x

def _conv_bn_crelu(x, n_filters, kernel_size):
    x = Conv2D(filters=n_filters, kernel_size=kernel_size, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Lambda(_crelu)(x)
    return x

Но мне интересно, что слой Lamda вводит некоторые издержки в обученииили процесс вывода?

Моя вторая попытка - создать слой keras, который будет обернут вокруг tf.nn.crelu

class CRelu(Layer):
    def __init__(self, **kwargs):
        super(CRelu, self).__init__(**kwargs)

    def build(self, input_shape):
        super(CRelu, self).build(input_shape)

    def call(self, x):
        x = tf.nn.crelu(x, axis=-1)
        return x

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = output_shape[-1] * 2
        output_shape = tuple(output_shape)
        return output_shape

def _conv_bn_crelu(x, n_filters, kernel_size):
    x = Conv2D(filters=n_filters, kernel_size=kernel_size, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = CRelu()(x)
    return x

Какая версия будет более эффективной?

Также с нетерпением ждудля чистой реализации Keras, если это возможно.

1 Ответ

0 голосов
/ 12 мая 2019

Я не думаю, что есть существенная разница между двумя реализациями по скорости.

Реализация Lambda на самом деле самая простая, но обычно лучше написать пользовательский слой, как вы это делали, особенно в том, что касается сохранения и загрузки модели ( get_config метод).

Но в этом случае это не имеет значения, поскольку CReLU тривиален и не требует сохранения и восстановления параметров.Вы можете сохранить параметр оси фактически как в коде ниже.Таким образом, он будет получен автоматически при загрузке модели.

class CRelu(Layer):
    def __init__(self, axis=-1, **kwargs):
        self.axis = axis 
        super(CRelu, self).__init__(**kwargs)

    def build(self, input_shape):
        super(CRelu, self).build(input_shape)

    def call(self, x):
        x = tf.nn.crelu(x, axis=self.axis)
        return x

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = output_shape[-1] * 2
        output_shape = tuple(output_shape)
        return output_shape

    def get_config(self, input_shape):
        config = {'axis': self.axis, }
        base_config = super(CReLU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...