Conditional Batch Normalization in Keras - PullRequest
0 votes
/ 21 April 2020

I want to implement a conditional batch normalization layer in Keras. I modified the code of the BatchNormalization layer to allow multiple gamma and beta variables, one per class. Here is the code I wrote:

from keras import regularizers, initializers, constraints
import keras.backend as K
from keras.layers import Layer, Input, InputSpec
from keras.models import Model
from keras.optimizers import Adam
import tensorflow as tf

import numpy as np

class ConditionalBatchNormalization(Layer):
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 num_classes=1,
                 **kwargs):
        super(ConditionalBatchNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = (
            initializers.get(moving_variance_initializer))
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
        self.num_classes = num_classes

    def build(self, input_shape):
        dim = input_shape[0][self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape[0]) + '.')
        self.input_spec = [InputSpec(ndim=len(input_shape[0]),
                                     axes={self.axis: dim}),
                           InputSpec(ndim=len(input_shape[1]))]
        shape = (dim,)
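        # One gamma/beta vector is created per class; the moving mean and
        # variance further below remain shared across all classes.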

        if self.scale:
            self.gamma = [self.add_weight(shape=shape,
                                          name='gamma_' + str(i),
                                          initializer=self.gamma_initializer,
                                          regularizer=self.gamma_regularizer,
                                          constraint=self.gamma_constraint)
                          for i in range(self.num_classes)]
        else:
            self.gamma = [None] * self.num_classes
        if self.center:
            self.beta = [self.add_weight(shape=shape,
                                         name='beta_' + str(i),
                                         initializer=self.beta_initializer,
                                         regularizer=self.beta_regularizer,
                                         constraint=self.beta_constraint)
                         for i in range(self.num_classes)]
        else:
            self.beta = [None] * self.num_classes

        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)
        self.moving_variance = self.add_weight(
            shape=shape,
            name='moving_variance',
            initializer=self.moving_variance_initializer,
            trainable=False)
        self.built = True

    def call(self, inputs, training=None):
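        # inputs is a pair: inputs[0] is the feature tensor to be normalized,
        # inputs[1] carries the integer class indices for gamma/beta selection.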
        input_shape = K.int_shape(inputs[0])
        input_class = inputs[1][0]
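        # Only the first element of the class input is used, i.e. the whole
        # batch is assumed to belong to a single class.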
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalize_inference():
            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(self.moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(self.moving_variance,
                                                      broadcast_shape)
                if self.center:
                    broadcast_beta = K.reshape(tf.gather(self.beta, input_class),
                                               broadcast_shape)
                else:
                    broadcast_beta = None
                if self.scale:
                    broadcast_gamma = K.reshape(tf.gather(self.gamma, input_class),
                                                broadcast_shape)
                else:
                    broadcast_gamma = None
                return K.batch_normalization(
                    inputs[0],
                    broadcast_moving_mean,
                    broadcast_moving_variance,
                    broadcast_beta,
                    broadcast_gamma,
                    #axis=self.axis,
                    epsilon=self.epsilon)
            else:
                return K.batch_normalization(
                    inputs[0],
                    self.moving_mean,
                    self.moving_variance,
                    tf.gather(self.beta, input_class),
                    tf.gather(self.gamma, input_class),
                    #axis=self.axis,
                    epsilon=self.epsilon)

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()

        # If the learning is either dynamic, or set to training:
        normed_training, mean, variance = K.normalize_batch_in_training(
            inputs[0], tf.gather(self.gamma, input_class),
            tf.gather(self.beta, input_class),
            reduction_axes,
            epsilon=self.epsilon)

        if K.backend() != 'cntk':
            sample_size = K.prod([K.shape(inputs[0])[axis]
                                  for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs[0]))
            if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
                sample_size = K.cast(sample_size, dtype='float32')

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

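        # Register the moving-average updates; passing inputs[0] makes them
        # conditional updates that run when this layer's output is computed.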
        self.add_update([K.moving_average_update(self.moving_mean,
                                                 mean,
                                                 self.momentum),
                         K.moving_average_update(self.moving_variance,
                                                 variance,
                                                 self.momentum)],
                        inputs[0])

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'num_classes': self.num_classes,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer':
                initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer':
                initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(ConditionalBatchNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape[0]


if __name__ == '__main__':
    x = Input((10,))
    c = Input(shape=(1,), dtype='int32')
    h = ConditionalBatchNormalization(num_classes=3)([x, c])
    model = Model([x, c], h)
    model.summary()
    model.compile(optimizer=Adam(1e-4), loss='mse')

    C = np.ones((100, 1)) * 0    # use * 1 or * 2 to select a different class
    X = np.random.rand(100, 10)
    Y = np.random.rand(100, 10)

    weights_before = model.layers[2].get_weights()

    model.train_on_batch(x=[X, C], y=Y)

    weights_after = model.layers[2].get_weights()
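
    # Sanity check: prints True for each weight array left unchanged by the
    # training step.
    for w_before, w_after in zip(weights_before, weights_after):
        print(np.allclose(w_before, w_after))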

If I change the condition C, the corresponding gamma and beta variables are updated. However, the moving_mean and moving_variance variables do not change after training. Am I missing something?
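
For reference, a minimal way to check whether the update ops were registered at all, sketched against the Keras 2.x API (the last line assumes the non-trainable moving statistics are the final two arrays returned by get_weights()):

    cbn = model.layers[2]    # the ConditionalBatchNormalization layer
    print(len(cbn.updates))  # expect 2 update ops: moving_mean and moving_variance

    K.set_learning_phase(1)              # force the training phase
    model.train_on_batch(x=[X, C], y=Y)
    print(cbn.get_weights()[-2:])        # moving statistics after the update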
