Как применить регуляризацию ядра в настраиваемом слое в Keras / TensorFlow? - PullRequest
0 голосов
/ 06 августа 2020

Рассмотрим следующий код настраиваемого слоя из учебника TensorFlow:

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]),
                                         self.num_outputs])

  def call(self, input):
    return tf.matmul(input, self.kernel)

Как применить любую предопределенную регуляризацию (например, tf.keras.regularizers.L1) или настраиваемую регуляризацию к параметрам настраиваемого слоя?

1 Ответ

0 голосов
/ 06 августа 2020

Метод add_weight принимает аргумент regularizer, который можно использовать для применения регуляризации к весу. Например:

self.kernel = self.add_weight("kernel",
                               shape=[int(input_shape[-1]), self.num_outputs],
                               regularizer=tf.keras.regularizers.l1_l2())

В качестве альтернативы, чтобы иметь больше контроля, как другие встроенные слои, вы можете изменить определение настраиваемого слоя и добавить аргумент kernel_regularizer в метод __init__:

from tensorflow.keras import regularizers

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs, kernel_regularizer=None):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs
    self.kernel_regularizer = regularizers.get(kernel_regularizer)

  def build(self, input_shape):
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]), self.num_outputs],
                                  regularizer=self.kernel_regularizer)

При этом вы даже можете передать строку типа 'l1' или 'l2' в kernel_regularizer аргумент при построении слоя, и это будет правильно разрешено.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...