Хотите выбрать переменные для обучения во время выполнения через заполнитель - PullRequest
1 голос
/ 17 октября 2019

Я ищу способ выбора обучаемых переменных для обновления во время выполнения на основе моего идентификатора эпохи. Я понимаю, что я могу назвать набор переменных в области видимости переменной как tf.variable_scope.

Если я создаю заполнитель как train_vars = tf.placeholder(shape = [None], dtype = type(tf.GraphKeys)), он выдаст следующую ошибку:

TypeError: Expected DataType for argument 'dtype' not <class 'type'>.

Чтоправильный способ передать этот список обучаемых через заполнитель или нет пути?

1 Ответ

0 голосов
/ 17 октября 2019

Один из способов сделать это - использовать tf.stop_gradient и tf.cond, чтобы выбрать, следует ли градиентам распространяться обратно в переменную, например, так:

v = tf.Variable(...)
v_is_trainable = tf.placeholder((), tf.bool)  # Or tf.placeholder_with_default
v_value = tf.cond(v_is_trainable, lambda: v, lambda: tf.stop_gradient(v))

Тогда простое добавление v_is_trainable: True или v_is_trainable: False к feed_dict сделает переменную обучаемой или нет для этого шага.

РЕДАКТИРОВАТЬ: Если вы хотите выбрать переменные с помощьюуказывая коллекцию обучаемых переменных, вы можете сделать что-то вроде этого

trainable_vars = tf.placeholder([None], tf.string)
v = tf.Variable(...)
# The variable is trainable if its name is in trainable_vars
v_value = tf.cond(tf.reduce_any(tf.equal(v.name, trainable_vars)),
                  lambda: v, lambda: tf.stop_gradient(v))

Тогда, если у вас есть все переменные, которые вы хотите обучить в коллекции, называемой 'MY_VARS', вы можете указать в feed_dict:

trainable_vars: [v.name for v in tf.get_collection('MY_VARS')]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...