Обтекаемый кеш, обученный керасом модель - PullRequest
0 голосов
/ 04 апреля 2020

Я обучил модель (через среду Keras), экспортировал ее с model.save('model.hdf5') и теперь хочу интегрировать ее с потрясающим Streamlit. Очевидно, я не хочу загружать модель каждый раз, когда конечный пользователь вставляет новый вход, но загружать ее раз и навсегда. поэтому мой код выглядит примерно так:

@st.cache
def load_my_model():
    model = load_model('model.hdf5')
    model.summary()

    return model

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model = load_my_model()
    if sentence:
        y_hat = model.predict(sentence)

Таким образом, я получил:

"streamlit.errors.UnhashableType:"

исключение , Я пытался использовать @st.cache(allow_output_mutation=True) и когда я запускаю запрос на странице с подсветкой. Я получил:

"Ошибка типа: невозможно интерпретировать ключ feed_dict как Tensor: Тензор Tensor (" input_1: 0 ", shape = (?, 80), dtype = int32) не является элементом этого графа . "

(Конечно, без каких-либо декораторов кеша модель загружается и работает нормально)

КАК правильно загрузить и кешировать обученную модель Keras ?

  • Python ver: 2.7 (к сожалению)
  • Keras ver: 2.1.3
  • Tensorflow ver: 1.3.0
  • Streamlit вер: 0.55.2

Большое спасибо!

1 Ответ

0 голосов
/ 11 апреля 2020

Решение было:

  1. добавление _make_predict_function() вызов
  2. возврат сеанса
from keras import backend as K

@st.cache(allow_output_mutation=True)
def load_model():
    model = load_model(MODEL_PATH)
    model._make_predict_function()
    model.summary()  # included to make it visible when model is reloaded
    session = K.get_session()
    return model, session

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model, session = load_model()
    if sentence:
        K.set_session(session)
        y_hat = model.predict(sentence)
...