Тензор не является элементом этого графика во время загрузки модели Keras после tf.reset_default_graph () - PullRequest
0 голосов
/ 07 августа 2020

У меня есть код, использующий низкоуровневый TF API. Я хочу добавить к нему код с помощью Keras. Я столкнулся с ошибкой crypti c в простейшем сценарии: у меня есть модель Keras, которая загружается и предсказывает правильно. Однако, когда я вызываю tf.reset_default_graph до того, как получаю сообщение об ошибке во время load_model

tf.reset_default_graph()
model = load_model("model.h5")

, я получаю: ValueError: Tensor Tensor («Placeholder: 0», shape = (40, 80), dtype = float32) не является элементом этого графика.

Проблема воспроизводится с помощью следующего минимального кода:

import tensorflow as tf
from keras.models import load_model

model = load_model("model.h5")
model.summary()

# tf.reset_default_graph() OR
tf.keras.backend.clear_session()

model = load_model("model.h5")
model.summary()

1 Ответ

0 голосов
/ 08 августа 2020

Отладка выявила, что проблема заключается в том, что Керас использует сеанс по умолчанию, если он существует, и если в этом сеансе была выполнена некоторая инициализация, сброс графика вызывает путаницу у Кераса, потому что ожидается, что состояние сеанса не изменится, а график сеанса не будет сброшен. Я не видел ничего из этого в документации, и это заставило меня потратить несколько часов на эту проблему. Поэтому, если я хочу загрузить модель, а затем использовать ее несколько раз с вызовами reset_default_graph между ними, мне нужно сохранить сеанс с графиком примерно так:

def load():
    with tf.Graph().as_default() as g:
        config = tf.ConfigProto(log_device_placement=False)
        config.gpu_options.allow_growth = True
        sess = tf.Session(graph=g, config=config)
        with sess.as_default():
            model = load_model("model.h5")
            model.summary()
            X = np.random.normal(0, 1, (20,2))
            pred = model.predict(X[np.newaxis])
            print(pred)

            return model, sess

model, sess = load()

with sess.as_default():
    X = np.random.normal(0, 1, (20, 2))
    pred = model.predict(X[np.newaxis])
    print(pred)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...