Это из-за разницы между переменной и get_variable.
Существует два способа решения этой проблемы:
1) введите переменную, отличную от ее имени.
def init_variable(sess=tf.Session()):
b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32)
tf.train.init_from_checkpoint('./temp/temp_model',
{'a': b})
init = tf.global_variables_initializer()
sess.run(init)
b = sess.run(b)
print(b)
return b
Поскольку, если это переменная, тензор потока может получить его напрямую :
if _is_variable(current_var_or_name) or (
isinstance(current_var_or_name, list)
and all(_is_variable(v) for v in current_var_or_name)):
var = current_var_or_name
В противном случае он должен получить переменную из хранилища переменных :
store_vars = vs._get_default_variable_store()._vars
Но переменная, определенная переменной, отсутствует в коллекции ('varstore_key',)
, как объясняется в этом ответе .
Затем 2) вы можете добавить его в коллекцию самостоятельно:
from tensorflow.python.ops.variable_scope import _VariableStore
from tensorflow.python.framework import ops
def init_variable(sess=tf.Session()):
b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32)
store = _VariableStore()
store._vars = {'b': b}
ops.add_to_collection(('__variable_store',), store)
tf.train.init_from_checkpoint('./temp/temp_model',
{'a': 'b'})
init = tf.global_variables_initializer()
sess.run(init)
b = sess.run(b)
print(b)
return b
Обе работы.