Keras не работает хорошо в многопоточной среде - PullRequest
0 голосов
/ 26 января 2019

У меня возникают проблемы при использовании модели keras в многопоточном задании в python. Мой код выглядит примерно так:

def load_my_model():
  from keras.models import load_model
  global model
  model = load_model(path_to_model)

def get_test_data(userid):
   ....
   return test_data_list  # returns a list of test data points for user

def predict_for_user(userid):
  load_my_model()
  test_data_list = get_test_data(userid)
  for el in test_data_list:
       model.predict(el)

from multiprocessing.pool import ThreadPool as Pool
pool = Pool(4)
result = pool.map(predict_for_user, user_id_list) 
pool.close()
pool.join()

Раньше это работало нормально, когда выполнялось в цикле for, но не при использовании функции пула из многопроцессорной обработки. Выдает ошибку:

Невозможно интерпретировать ключ feed_dict как Tensor: Tensor Tensor ("Placeholder: 0", shape = (300, 120), dtype = float32) не является элементом этого графика.

После некоторых онлайн-предложений я изменил код на:

def load_my_model():
  from keras.models import load_model
  import tensorflow as tf

  global model
  model = load_model(path_to_model)

  global graph
  graph = tf.get_default_graph()


def get_test_data(userid):
   ....
   return test_data_list  # returns a list of test data points for user

def predict_for_user(userid):
  load_my_model()
  test_data_list = get_test_data(userid)
  for el in test_data_list:
       with graph.as_default():
          model.predict(el)

from multiprocessing.pool import ThreadPool as Pool
pool = Pool(4)
result = pool.map(predict_for_user, user_id_list) 
pool.close()
pool.join()

Это помогло как-то, поскольку теперь модель хорошо предсказывает пользователя, не выдавая ошибку, но похоже, что при переходе к следующему пользователю в списке пула выдает ту же ошибку, что и раньше:

Невозможно интерпретировать ключ feed_dict как Tensor: Tensor Tensor ("Placeholder: 0", shape = (300, 120), dtype = float32) не является элементом этого графа.

Я уверен, что это как-то связано с тем, что тензоры не загружаются должным образом во время многопоточности, но не уверен, как это исправить в моем коде выше.

Любая помощь была бы так признательна!

...