Потокобезопасное использование графов в Tensorflow - PullRequest
0 голосов
/ 10 января 2019

У меня есть приложение для колб, которое сначала загружает модель keras, а затем выполняет функцию прогнозирования. Согласно этому ответу я сохраняю график тензорного потока в глобальной переменной и использую тот же график для функции прогнозирования.

def load_model():
    load_model_from_file()
    global graph
    graph = tf.get_default_graph()

def predict():
    with graph.as_default():
        tagger = Tagger(self.model, preprocessor=self.p)
        return tagger.analyze(words)

@app.route('/predict', methods=['GET'])
def load_and_predict():
    load_model()
    predict()

Однако это приводит к проблеме, когда на сервер отправляется несколько запросов. Как я могу сделать этот код поточно-ориентированным или, если быть более точным, как правильно использовать тензорные графы в многопоточной среде?

Ответы [ 2 ]

0 голосов
/ 10 января 2019

Вы можете синхронизировать его с блокировкой.

import threading    
lock = threading.Lock()

def load_and_predict():
     with lock:
        load_model()
        predict()
0 голосов
/ 10 января 2019

обычно вы должны использовать сеанс при работе с потоками в tenorflow.

intra_parallel_thread_tf = 1
inter_parallel_thread_tf = 1

session_conf = tf.ConfigProto(intra_op_parallelism_threads=intra_parallel_thread_tf,
                          inter_op_parallelism_threads=inter_parallel_thread_tf)

tf.Session(graph=tf.get_default_graph(), config=session_conf)
GRAPH = tf.get_default_graph()

Но это довольно общее. Это также зависит от ошибки, которую вы получаете.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...