Я хочу реализовать пользовательский слой с именем ActNorm (нормализация активации), который требует инициализации параметров из среднего значения пакета и дисперсии пакета входных данных.
Следующий код взят из официальной реализации openai
# Activation normalization
@add_arg_scope
def actnorm_scale(name, x, scale=1., logdet=None, logscale_factor=3., batch_variance=False, reverse=False,
trainable=True):
x_shape = int_shape(x)
with tf.variable_scope(name), arg_scope([get_variable_ddi], trainable=trainable):
assert len(x_shape) == 2 or len(x_shape) == 3
if len(x_shape) == 2:
x_var = tf.reduce_mean(x ** 2, [0], keepdims=True)
logdet_factor = 1
_shape = [1, x_shape[1]]
elif len(x_shape) == 3:
x_var = tf.reduce_mean(x ** 2, [0, 1], keepdims=True)
logdet_factor = int(x_shape[1])
_shape = [1, 1, x_shape[2]]
if batch_variance:
x_var = tf.reduce_mean(x ** 2, keepdims=True)
logs = get_variable_ddi("logs", _shape, initial_value=tf.log(
scale / (tf.sqrt(x_var) + 1e-6)) / logscale_factor) * logscale_factor
if not reverse:
x = x * tf.exp(logs)
else:
x = x * tf.exp(-logs)
if logdet is not None:
dlogdet = tf.reduce_sum(tf.log(tf.abs(tf.exp(logs)))) * logdet_factor
if reverse:
dlogdet *= -1
return x, logdet + dlogdet
return x
и вот код, который я пытаюсь:
class ActNorm(Layer):
"""Activation normalization(scale and shift)"""
def __init__(self, scale=1.):
super(ActNorm, self).__init__()
self.scale = scale
self.b = None
self.logs = None
def build(self, input_shape):
batch_shape = (1, )*(len(input_shape)-1) + (input_shape[-1], )
self.b = self.add_weight(
name="shift",
shape=batch_shape,
dtype=tf.float32,
initializer='zeros',
trainable=True
)
self.logs = self.add_weight(
name="logscale",
shape=batch_shape,
dtype=tf.float32,
initializer='zeros',
trainable=True
)
super(ActNorm, self).build(input_shape)
def call(self, inputs, **kwargs):
batch_mean, batch_var = tf.nn.moments(inputs, list(range(len(inputs.shape) - 1)), keepdims=True)
self.logs.assign(tf.math.log(self.scale / (tf.sqrt(batch_var) + 1e-6)))
# How to initialize self.b and self.logs only once?
Проблема, с которой я столкнулся, заключается в том, что я не знаю, как определить, что self.logs
и self.b
инициализируется только один раз.