Как установить параметр осей в моментах TensorFlow для нормализации партии? - PullRequest
1 голос
/ 07 ноября 2019

Я планирую реализовать функцию нормализации партии, аналогичную этому блогу (или просто использовать tf.nn.batch_normalization) с использованием tf.nn.moments для вычисления среднего идисперсия, но я хочу сделать это для временных данных, как вектор, так и тип изображения. Обычно у меня возникают небольшие проблемы с пониманием того, как правильно установить аргумент axes в tf.nn.moments.

Мои входные данные для векторных последовательностей имеют форму (batch, timesteps, channels), а мои входные данные для последовательностей изображений имеют форму (batch, timesteps, height, width, 3) (обратите внимание, что это изображения RGB). В обоих случаях я хочу, чтобы нормализация происходила по всему пакету и по временным шагам, то есть я не пытаюсь сохранить отдельное среднее значение / дисперсию для разных временных шагов.

Как правильно установить axes для разных типов данных (например, изображения, вектора) и для временных / не временных?

1 Ответ

1 голос
/ 07 ноября 2019

Самый простой способ думать об этом - оси, переданные в axes, будут свернуты , а статистика будет вычислена путем разрезания по axes. Пример:

import tensorflow as tf

x = tf.random.uniform((8, 10, 4))

print(x, '\n')
print(tf.nn.moments(x, axes=[0]), '\n')
print(tf.nn.moments(x, axes=[0, 1]))
Tensor("random_uniform:0", shape=(8, 10, 4), dtype=float32)

(<tf.Tensor 'moments/Squeeze:0'   shape=(10, 4) dtype=float32>,
 <tf.Tensor 'moments/Squeeze_1:0' shape=(10, 4) dtype=float32>)

(<tf.Tensor 'moments_1/Squeeze:0'   shape=(4,) dtype=float32>,
 <tf.Tensor 'moments_1/Squeeze_1:0' shape=(4,) dtype=float32>)

Из источника math_ops.reduce_mean используется для вычисления как mean, так и variance, который работает как в псевдокоде:

# axes = [0]
mean = (x[0, :, :] + x[1, :, :] + ... + x[7, :, :]) / 8
mean.shape == (10, 4)  # each slice's shape is (10, 4), so sum's shape is also (10, 4)

# axes = [0, 1]
mean = (x[0, 0,  :] + x[1, 0,  :] + ... + x[7, 0,  :] +
        x[0, 1,  :] + x[1, 1,  :] + ... + x[7, 1,  :] +
        ... +
        x[0, 10, :] + x[1, 10, :] + ... + x[7, 10, :]) / (8 * 10)
mean.shape == (4, ) # each slice's shape is (4, ), so sum's shape is also (4, )

Другими словами, axes=[0] будет вычислять (timesteps, channels) статистику по samples - т.е. итерировать по samples, вычислять среднее значение и дисперсию (timesteps, channels) срезов. Таким образом, для

нормализация должна происходить по всему пакету и по временным шагам, то есть я не пытаюсь сохранить отдельное среднее значение / дисперсию для разных временных шагов

вам просто нужно свернуть измерение timesteps (вдоль samples) и вычислить статистику, выполнив итерации для samples и timesteps:

axes = [0, 1]

Та же история для изображений, за исключениему вас есть два неканальных / семпловых измерения, вы должны сделать axes = [0, 1, 2] (чтобы свернуть samples, height, width).


Демонстрация псевдокода : посмотреть среднее вычисление в действии

import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np

x = tf.constant(np.random.randn(8, 10, 4))
result1 = tf.add(x[0], tf.add(x[1], tf.add(x[2], tf.add(x[3], tf.add(x[4], 
                       tf.add(x[5], tf.add(x[6], x[7]))))))) / 8
result2 = tf.reduce_mean(x, axis=0)
print(K.eval(result1 - result2))
# small differences per numeric imprecision
[[ 2.77555756e-17  0.00000000e+00 -5.55111512e-17 -1.38777878e-17]
 [-2.77555756e-17  2.77555756e-17  0.00000000e+00 -1.38777878e-17]
 [ 0.00000000e+00 -5.55111512e-17  0.00000000e+00 -2.77555756e-17]
 [-1.11022302e-16  2.08166817e-17  2.22044605e-16  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [-5.55111512e-17  2.77555756e-17 -1.11022302e-16  5.55111512e-17]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -2.77555756e-17]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -5.55111512e-17]
 [ 0.00000000e+00 -3.46944695e-17 -2.77555756e-17  1.11022302e-16]
 [-5.55111512e-17  5.55111512e-17  0.00000000e+00  1.11022302e-16]]
...