Проверено на TF 1.15
В основном ошибка говорит о том, что он находит несколько ссылок на одну и ту же переменную в диктовке restore_variables
. Исправить это просто. Создайте копию своей переменной, используя tf.Variable(varr)
, как показано ниже для одной из ссылок.
Я думаю, можно с уверенностью предположить, что вы не ищете здесь несколько ссылок на одну и ту же переменную, а скорее две отдельные переменные. (Я предполагаю это, потому что, если вы хотите использовать одну и ту же переменную несколько раз, вы можете просто использовать одну переменную несколько раз).
with tf.Session() as sess:
saver.restore(sess, './vars/vars.ckpt-0')
restore_variables = {}
checkpoint_variables=['b']
for varr in tf.global_variables():
if varr.op.name in checkpoint_variables:
restore_variables[varr.op.name.split("_red")[0]] = varr
restore_variables[varr.op.name.split("_blue")[0]] = tf.Variable(varr)
print(restore_variables)
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
Ниже вы можете найти полный код для репликации проблемы используя игрушечный пример. По сути, у нас есть две переменные a
и b
, и из этого мы создаем b_red
и b_blue
переменные.
# Saving the variables
import tensorflow as tf
import numpy as np
a = tf.placeholder(shape=[None, 3], dtype=tf.float64)
w1 = tf.Variable(np.random.normal(size=[3,2]), name='a')
out = tf.matmul(a, w1)
w2 = tf.Variable(np.random.normal(size=[2,3]), name='b')
out = tf.matmul(out, w2)
saver = tf.train.Saver([w1, w2])
with tf.Session() as sess:
tf.global_variables_initializer().run()
saved_path = saver.save(sess, './vars/vars.ckpt', global_step=0)
# Restoring the variables
with tf.Session() as sess:
saver.restore(sess, './vars/vars.ckpt-0')
restore_variables = {}
checkpoint_variables=['b']
for varr in tf.global_variables():
if varr.op.name in checkpoint_variables:
restore_variables[varr.op.name+"_red"] = varr
# Fixing the issue: Instead of varr, do tf.Variable(varr)
restore_variables[varr.op.name+"_blue"] = varr
print(restore_variables)
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)