tf.nn.fused_batch_norm
оптимизирован и сделал свое дело.
Мне пришлось создать два подграфа, по одному на режим, поскольку интерфейс fused_batch_norm
не принимает режим условного обучения / тестирования (is_training - это bool, а не тензор, поэтому его график не является условным). Я добавил условие после (см. Ниже). Тем не менее, даже с двумя подграфами, это имеет примерно одинаковое время выполнения tf.layers.batch_normalization
.
Вот окончательное решение (я все равно буду благодарен за любые комментарии или советы по улучшению):
def batchnorm(self, x, name, epsilon=0.001, decay=0.99):
with tf.variable_scope(name):
shape = x.get_shape().as_list()
channels_num = shape[3]
# scale factor
gamma = tf.get_variable("gamma", shape=[channels_num], initializer=tf.constant_initializer(1.0), trainable=True)
# shift value
beta = tf.get_variable("beta", shape=[channels_num], initializer=tf.constant_initializer(0.0), trainable=True)
moving_mean = tf.get_variable("moving_mean", channels_num, initializer=tf.constant_initializer(0.0), trainable=False)
moving_var = tf.get_variable("moving_var", channels_num, initializer=tf.constant_initializer(1.0), trainable=False)
(output_train, batch_mean, batch_var) = tf.nn.fused_batch_norm(x,
gamma,
beta, # pylint: disable=invalid-name
mean=None,
variance=None,
epsilon=epsilon,
data_format="NHWC",
is_training=True,
name="_batchnorm_op")
(output_test, _, _) = tf.nn.fused_batch_norm(x,
gamma,
beta, # pylint: disable=invalid-name
mean=moving_mean,
variance=moving_var,
epsilon=epsilon,
data_format="NHWC",
is_training=False,
name="_batchnorm_op")
output = tf.cond(self.is_training, lambda: tf.identity(output_train), lambda: tf.identity(output_test))
update_mean = moving_mean.assign((decay * moving_mean) + ((1. - decay) * batch_mean))
update_var = moving_var.assign((decay * moving_var) + ((1. - decay) * batch_var))
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var)
return output