Документация действительно сбивает с толку.По умолчанию локальная переменная также добавляется в коллекцию обучаемых переменных.Вы можете проверить это, проверив tf.trainable_variables()
.Итак, похоже, что локальная переменная не обучаема, недостаточно добавить ее в коллекцию LOCAL_VARIABLES
, но вам нужно ключевое слово trainable=False
.
короткий скрипт, который показывает, что локальная и глобальная переменные обновляются в цикле обучения:
import tensorflow as tf
my_local = tf.get_variable("my_local", shape=(), collections=[tf.GraphKeys.LOCAL_VARIABLES],
initializer=tf.constant_initializer(1.0))
my_global = tf.get_variable("my_global", shape=(),
initializer=tf.constant_initializer(2.0))
target_value = tf.constant(4.0)
loss = tf.abs(my_local + my_global - target_value)
optim = tf.train.AdamOptimizer(learning_rate=1.0).minimize(loss)
for v in tf.trainable_variables():
print(v.name)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
print("local init: ", sess.run(my_local))
print("global init: ", sess.run(my_global))
for i in range(2):
_, l = sess.run([optim, loss])
print("loss {:.4f}".format(l))
print("local: ", sess.run(my_local))
print("global: ", sess.run(my_global))
, который печатает
my_local:0
my_global:0
local init: 1.0
global init: 2.0
loss 1.0000
local: 1.9999996
global: 2.9999995
loss 1.0000
local: 1.9473683
global: 2.9473681
Значение my_local
не изменится, если выустановите trainable=False
в вызове на tf.get_variable
.