Раствор 1
Вам нужно передать аргумент инициализатора для get_variable
, как
import tensorflow as tf
import numpy as np
ref0 = tf.get_variable('block0', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref1 = tf.get_variable('block1', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref2 = tf.get_variable('block2', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref4 = tf.get_variable('foo0', [2], initializer=tf.truncated_normal_initializer(mean=10))
ref5 = tf.get_variable('foo1', [2], initializer=tf.truncated_normal_initializer(mean=10))
block_vars = [v for v in tf.global_variables() if 'block' in v.name]
block_vars_complement = [v for v in tf.global_variables() if 'block' not in v.name]
with tf.Session() as sess:
sess.run(tf.variables_initializer(var_list=block_vars))
print(np.mean(sess.run([ref0, ref1, ref2])), 'should be ~ 40')
try:
print(np.mean(sess.run([ref4])))
except Exception as e:
print('[INFO] failed as expected with message %s' % e)
sess.run(tf.variables_initializer(var_list=block_vars_complement))
print(np.mean(sess.run([ref4, ref5])), 'should be ~ 10')
Раствор 2
Если вы не хотите передавать initializer
каждому get_variable
, вы можете использовать пользовательский метод получения, такой как
import tensorflow as tf
import numpy as np
def my_getter(getter, name, shape, *args, **kwargs):
if 'block' not in name:
return getter(name=name, shape=shape, *args, **kwargs)
else:
kwargs['initializer'] = tf.truncated_normal_initializer(mean=40)
return getter(name=name, shape=shape, *args, **kwargs)
with tf.variable_scope("some_scopename", custom_getter=my_getter):
ref0 = tf.get_variable('block0', [2], initializer=tf.truncated_normal_initializer(mean=10))
ref1 = tf.get_variable('block1', [2], initializer=tf.truncated_normal_initializer(mean=10))
ref2 = tf.get_variable('block2', [2], initializer=tf.truncated_normal_initializer(mean=10))
ref4 = tf.get_variable('foo0', [2], initializer=tf.truncated_normal_initializer(mean=10))
ref5 = tf.get_variable('foo1', [2], initializer=tf.truncated_normal_initializer(mean=10))
block_vars = [v for v in tf.global_variables() if 'block' in v.name]
block_vars_complement = [v for v in tf.global_variables() if 'block' not in v.name]
with tf.Session() as sess:
sess.run(tf.variables_initializer(var_list=block_vars))
print(np.mean(sess.run([ref0, ref1, ref2])), 'should be ~ 40')
try:
print(np.mean(sess.run([ref4])))
except Exception as e:
print('[INFO] failed as expected with message %s' % e)
sess.run(tf.variables_initializer(var_list=block_vars_complement))
print(np.mean(sess.run([ref4, ref5])), 'should be ~ 10')
Решение 3
tf.truncated_normal_initializer
или другие инициализаторы - это просто самостоятельные операции. Следовательно, они могут быть применены в цикле ко всем переменным из коллекции, и такое групповое обновление может быть наконец применено (см. initialize_collection
):
import tensorflow as tf
import numpy as np
ref0 = tf.get_variable('block0', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref1 = tf.get_variable('block1', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref2 = tf.get_variable('block2', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref4 = tf.get_variable('foo0', [2], initializer=tf.truncated_normal_initializer(mean=10))
ref5 = tf.get_variable('foo1', [2], initializer=tf.truncated_normal_initializer(mean=10))
block_vars = [v for v in tf.global_variables() if 'block' in v.name]
block_vars_complement = [v for v in tf.global_variables() if 'block' not in v.name]
def initialize_collection(collection, initializer):
ops = []
for v in collection:
ops.append(v.assign(initializer(shape=v.shape)))
return tf.group(ops)
with tf.Session() as sess:
sess.run(tf.variables_initializer(var_list=block_vars))
print(np.mean(sess.run([ref0, ref1, ref2])), 'should be ~ 40')
sess.run(initialize_collection(block_vars, tf.truncated_normal_initializer(mean=-40, stddev=0.01)))
print(np.mean(sess.run([ref0, ref1, ref2])), 'should be ~ -40')