`get_variable ()` не распознает существующие переменные для tf.estimator - PullRequest
0 голосов
/ 26 ноября 2018

Этот вопрос был задан здесь , разница в том, что моя проблема сфокусирована на Estimator.

Некоторый контекст: Мы обучили модель с использованием оценщика и получили некоторую переменную, определенную в Оценщикеinput_fn, эта функция предварительно обрабатывает данные в пакеты.Теперь мы переходим к прогнозированию.Во время прогнозирования мы используем тот же input_fn для считывания и обработки данных. Но получена ошибка о том, что переменная (word_embeddings) не существует (переменные существуют в графе chkp), вот соответствующий бит кода в input_fn:

with tf.variable_scope('vocabulary', reuse=tf.AUTO_REUSE):
    if mode == tf.estimator.ModeKeys.TRAIN:
        word_to_index, word_to_vec = load_embedding(graph_params["word_to_vec"])
        word_embeddings = tf.get_variable(initializer=tf.constant(word_to_vec, dtype=tf.float32),
                                          trainable=False,
                                          name="word_to_vec",
                                          dtype=tf.float32)
    else:
        word_embeddings = tf.get_variable("word_to_vec", dtype=tf.float32)

, в основном, когда этов режиме прогнозирования else вызывается для загрузки переменных в контрольной точке.Неспособность распознать эту переменную указывает на а) неправильное использование области видимости;б) график не восстановлен.Я не думаю, что здесь важна область видимости, если reuse установлен правильно.

Я подозреваю, что это потому, что график еще не восстановлен на фазе input_fn.Обычно график восстанавливается путем вызова saver.restore(sess, "/tmp/model.ckpt") reference .Расследование оценки исходный код не дает мне ничего, что связано с восстановлением, лучший выстрел - MonitoredSession, оболочка для обучения.Это уже было так много от первоначальной проблемы, не уверен, что я на правильном пути, я ищу здесь помощь, если у кого-то есть какие-либо идеи.

Краткое описание моего вопроса в одной строке: Какграфик восстанавливается в течение tf.estimator, через input_fn или model_fn?

1 Ответ

0 голосов
/ 12 декабря 2018

Здравствуйте, я думаю, что ваша ошибка возникает просто потому, что вы не указали фигуру в tf.get_variable (при прогнозировании), похоже, вам нужно указать фигуру, даже если переменная будет восстановлена.

Я провел следующий тест с простой оценкой линейного регрессора, которой просто нужно предсказать x + 5

def input_fn(mode):
    def _input_fn():
        with tf.variable_scope('all_input_fn', reuse=tf.AUTO_REUSE):
            if mode == tf.estimator.ModeKeys.TRAIN:
                var_to_follow = tf.get_variable('var_to_follow', initializer=tf.constant(20))
                x_data = np.random.randn(1000)
                labels = x_data + 5
                return {'x':x_data}, labels
            elif mode == tf.estimator.ModeKeys.PREDICT:
                var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])
                return {'x':[0,10,100,var_to_follow]}
    return _input_fn

featcols = [tf.feature_column.numeric_column('x')]
model = tf.estimator.LinearRegressor(featcols, './outdir')

Этот код работает отлично, значение const равно 20, а также длявесело использовать его в моем тестовом наборе, чтобы подтвердить: p

Однако, если вы удалите форму = [], она сломается, вы также можете дать другой инициализатор, такой как tf.constant (500), и все будет работать и 20будет использоваться.

Запустив

model.train(input_fn(tf.estimator.ModeKeys.TRAIN), max_steps=10000)

и

preds = model.predict(input_fn(tf.estimator.ModeKeys.PREDICT))
print(next(preds))

Вы можете визуализировать график, и вы увидите, что а) область видимости нормальная и б) график восстановлен.

Надеюсь, это поможет вам.

...