Является ли локальная переменная обучаемой по умолчанию или нет? - PullRequest
0 голосов
/ 20 февраля 2019

Когда я гуляю по руководству https://www.tensorflow.org/guide/variables,, меня смущает приведенное ниже описание (жирный шрифт):

По умолчанию каждый tf.Variable помещается в следующие две коллекции:

  • tf.GraphKeys.GLOBAL_VARIABLES --- переменные, которые могут совместно использоваться несколькими устройствами,
  • tf.GraphKeys.TRAINABLE_VARIABLES --- переменные, для которых TensorFlow будет вычислять градиенты.

Если вы не хотите, чтобы переменная была обучаемой , добавьте ее в коллекцию tf.GraphKeys.LOCAL_VARIABLES.Например, следующий фрагмент демонстрирует, как добавить переменную с именем my_local в эту коллекцию:

my_local = tf.get_variable("my_local", shape=(), collections [tf.GraphKeys.LOCAL_VARIABLES])`

В качестве альтернативы вы можете указать trainable=False в качестве аргумента для tf.get_variable:

my_non_trainable = tf.get_variable("my_non_trainable", shape=(), trainable=False)

Но когда я создаю локальную переменную, она автоматически добавляется в коллекцию tf.GraphKeys.TRAINABLE_VARIABLES, что означает, что она обучаема.Итак, обучаемая локальная переменная или нет?

1 Ответ

0 голосов
/ 20 февраля 2019

Документация действительно сбивает с толку.По умолчанию локальная переменная также добавляется в коллекцию обучаемых переменных.Вы можете проверить это, проверив tf.trainable_variables().Итак, похоже, что локальная переменная не обучаема, недостаточно добавить ее в коллекцию LOCAL_VARIABLES, но вам нужно ключевое слово trainable=False.

короткий скрипт, который показывает, что локальная и глобальная переменные обновляются в цикле обучения:

import tensorflow as tf

my_local = tf.get_variable("my_local", shape=(), collections=[tf.GraphKeys.LOCAL_VARIABLES],
                           initializer=tf.constant_initializer(1.0))
my_global = tf.get_variable("my_global", shape=(),
                            initializer=tf.constant_initializer(2.0))

target_value = tf.constant(4.0)
loss = tf.abs(my_local + my_global - target_value)
optim = tf.train.AdamOptimizer(learning_rate=1.0).minimize(loss)

for v in tf.trainable_variables():
    print(v.name)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    print("local init: ", sess.run(my_local))
    print("global init: ", sess.run(my_global))
    for i in range(2):
        _, l = sess.run([optim, loss])
        print("loss {:.4f}".format(l))
        print("local: ", sess.run(my_local))
        print("global: ", sess.run(my_global))

, который печатает

my_local:0
my_global:0
local init:  1.0
global init:  2.0
loss 1.0000
local:  1.9999996
global:  2.9999995
loss 1.0000
local:  1.9473683
global:  2.9473681

Значение my_local не изменится, если выустановите trainable=False в вызове на tf.get_variable.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...