В TensorFlow, как я могу посмотреть на параметры нормализации партии? - PullRequest
0 голосов
/ 13 ноября 2018

Я использую слой tf.layers.batch_normalization в своей сети. Как вы, возможно, знаете, в пакетной нормализации используются обучаемые параметры гамма и бета для каждой единицы u_i в этом слое, чтобы выбрать собственное стандартное отклонение и среднее значение по u_i (x) для различных входов x. Обычно гамма инициализируется до 1, а бета до 0.

Мне интересно посмотреть на значения гаммы и бета, которые изучаются в различных единицах, чтобы собрать статистику о том, где они, как правило, оказываются после обучения сети. Как я могу посмотреть их текущие значения во время каждого учебного экземпляра?

1 Ответ

0 голосов
/ 13 ноября 2018

Вы можете получить все переменные в области слоя нормализации партии и распечатать их.Пример:

import tensorflow as tf

tf.reset_default_graph()
x = tf.constant(3.0, shape=(3,))
x = tf.layers.batch_normalization(x)

print(x.name) # batch_normalization/batchnorm/add_1:0

variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              scope='batch_normalization')
print(variables)

#[<tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32_ref>,
#  <tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32_ref>]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    gamma = sess.run(variables[0])
    print(gamma) # [1. 1. 1.]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...