Тензорные вложения: Как быстро вычислить тензорные вложения-концентраторы? - PullRequest
0 голосов
/ 20 марта 2020

Я работаю над решением, в котором я получаю предложение в качестве ввода, отправляю его в tf-hub для встраивания и возвращаю обратно вложение. Мой код выглядит так:

def get_embedding(model_type, version, sentence):

    model = tf_hub.load_model(model_type, version)
    similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
    similarity_message_encodings = model(similarity_input_placeholder)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())

        message_embeddings = session.run(similarity_message_encodings,
                                        feed_dict={similarity_input_placeholder: sentences})


    return message_embeddings

Но проблема в том, что каждый раз, когда я получаю новые предложения, отправляю в tf-hub его загрузку и вычисляю вывод, что довольно медленно.

Как мне держать сеанс открытым и вычислять вложения внутри сеанса, не загружая каждый раз график и таблицы?

Я думал, что-то вроде этого:

sess =  tf.Session()    
graph=tf.get_default_graph()

    def get_embedding(model_type, version, sentence):

    model = tf_hub.load_model(model_type, version)
    similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
    similarity_message_encodings = model(similarity_input_placeholder)

    with tf.Session() as session:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())

        message_embeddings = sess.run(similarity_message_encodings,
                                        feed_dict={similarity_input_placeholder: sentences})


    return message_embeddings

сделать глобальный sess но это не помогает. Любые предложения о том, как быстро вычислить вложения или как сохранить сеанс открытым, чтобы он не загружал модель каждый раз.

Спасибо!

...