Как правильно реализовать пользовательскую функцию сквоша в TF2.0 (пользовательский слой или другой) - PullRequest
1 голос
/ 03 апреля 2019

Я пытаюсь реализовать простую модель capsnet в TF2.0.

Пока что я добавил несколько слоев conv2d и слой изменения формы, но мне нужно добавить функцию сквоша.Проблема в том, что tf.norm() отправит меня на NaN землю, так как я раздавливаю целые векторы, поэтому я должен использовать пользовательскую функцию сквоша.Я никогда не писал пользовательский слой раньше, и я в основном просто использовал шаблон из учебника и добавил математическую функцию в call().

, так как я делаю все это внутри keras.models.Sequential модели.Я не был уверен, как получить результат после первых нескольких слоев, поэтому я просто решил сделать функцию сквоша своим собственным слоем в модели.Я чувствую, что это, вероятно, совершенно и совершенно неправильно, поэтому я ищу какой-то вклад в лучший способ сделать это.

Должен ли я вообще использовать keras.Model для этого, или я должен использовать новую функцию быстрого исполнения, чтобы просто пропустить тензоры через слои вручную?Если можно использовать SquashLayer(), который я реализовал, то что я передаю в качестве аргумента, чтобы получить правильный вывод для перехода на следующий уровень?

class SquashLayer(tf.keras.layers.Layer):
    def __init__(self, output_units):
        super(SquashLayer, self).__init__()
        self.output_units = output_units

    def build(self, input_shape):
        self.kernel = self.add_variable(
          'kernel', [input_shape[-1], self.output_units])

    def call(self, input):
        squared_norm = tf.reduce_sum(tf.square(input), axis=-1, keepdims=True)
        safe_norm = tf.sqrt(squared_norm + 1e-7)
        squash_factor = squared_norm / (1. + squared_norm)
        unit_vector = input / safe_norm
        return squash_factor * unit_vector

model = keras.models.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation=tf.nn.relu, name='conv1'),
    keras.layers.Conv2D(filters=256, kernel_size=9, strides=2, padding='valid', activation=tf.nn.relu, name='conv2'),
    keras.layers.Reshape((-1, caps1_n_caps, caps1_n_dims)),
    SquashLayer()
    ])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...