Здравствуйте, я думаю, что ваша ошибка возникает просто потому, что вы не указали фигуру в tf.get_variable (при прогнозировании), похоже, вам нужно указать фигуру, даже если переменная будет восстановлена.
Я провел следующий тест с простой оценкой линейного регрессора, которой просто нужно предсказать x + 5
def input_fn(mode):
def _input_fn():
with tf.variable_scope('all_input_fn', reuse=tf.AUTO_REUSE):
if mode == tf.estimator.ModeKeys.TRAIN:
var_to_follow = tf.get_variable('var_to_follow', initializer=tf.constant(20))
x_data = np.random.randn(1000)
labels = x_data + 5
return {'x':x_data}, labels
elif mode == tf.estimator.ModeKeys.PREDICT:
var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])
return {'x':[0,10,100,var_to_follow]}
return _input_fn
featcols = [tf.feature_column.numeric_column('x')]
model = tf.estimator.LinearRegressor(featcols, './outdir')
Этот код работает отлично, значение const равно 20, а также длявесело использовать его в моем тестовом наборе, чтобы подтвердить: p
Однако, если вы удалите форму = [], она сломается, вы также можете дать другой инициализатор, такой как tf.constant (500), и все будет работать и 20будет использоваться.
Запустив
model.train(input_fn(tf.estimator.ModeKeys.TRAIN), max_steps=10000)
и
preds = model.predict(input_fn(tf.estimator.ModeKeys.PREDICT))
print(next(preds))
Вы можете визуализировать график, и вы увидите, что а) область видимости нормальная и б) график восстановлен.
Надеюсь, это поможет вам.