Один из способов сделать это - использовать tf.stop_gradient
и tf.cond
, чтобы выбрать, следует ли градиентам распространяться обратно в переменную, например, так:
v = tf.Variable(...)
v_is_trainable = tf.placeholder((), tf.bool) # Or tf.placeholder_with_default
v_value = tf.cond(v_is_trainable, lambda: v, lambda: tf.stop_gradient(v))
Тогда простое добавление v_is_trainable: True
или v_is_trainable: False
к feed_dict
сделает переменную обучаемой или нет для этого шага.
РЕДАКТИРОВАТЬ: Если вы хотите выбрать переменные с помощьюуказывая коллекцию обучаемых переменных, вы можете сделать что-то вроде этого
trainable_vars = tf.placeholder([None], tf.string)
v = tf.Variable(...)
# The variable is trainable if its name is in trainable_vars
v_value = tf.cond(tf.reduce_any(tf.equal(v.name, trainable_vars)),
lambda: v, lambda: tf.stop_gradient(v))
Тогда, если у вас есть все переменные, которые вы хотите обучить в коллекции, называемой 'MY_VARS', вы можете указать в feed_dict
:
trainable_vars: [v.name for v in tf.get_collection('MY_VARS')]