Вывод модели keras / tenorflow в многопоточном env генерирует ошибку «не элемент этого графа» - PullRequest
0 голосов
/ 03 июля 2019

Я работаю над обслуживанием обученной модели keras через REST API с помощью веб-сервера flask, но обнаружил ошибку «не элемент этого графика».Я думаю, что это связано с многопоточной средой или проблемой безопасности потока, поэтому я нашел следующий код, чтобы найти проблему:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import threading
import tensorflow as tf

model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

print(threading.current_thread(), tf.get_default_graph())

def thread_fn():
    print(threading.current_thread(), tf.get_default_graph())
    model.predict(np.array([[0,0,0,0,0,0,0,0]]))

th = threading.Thread(target=thread_fn)
th.start()
th.join()

Запуск вышеуказанного кода приводит к той же ошибке, когда я работаюмодель с использованием колбы.Однако я также заметил, что график тензорного потока отличается при загрузке модели и при прогнозировании.

<_MainThread(MainThread, started 4532807104)> <tensorflow.python.framework.ops.Graph object at 0x101c3e240>
<Thread(Thread-1, started 123145439219712)> <tensorflow.python.framework.ops.Graph object at 0xb33df5d30>

Поэтому я попробовал следующее решение:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import threading
import tensorflow as tf

model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

grpah = tf.get_default_graph()
print(threading.current_thread(), tf.get_default_graph())

def thread_fn():
    with grpah.as_default():
        print(threading.current_thread(), tf.get_default_graph())
        model.predict(np.array([[0,0,0,0,0,0,0,0]]))

th = threading.Thread(target=thread_fn)
th.start()
th.join()

Я сохраняю график при загрузке модели и использую этот график при прогнозировании.На этот раз ошибка исчезает.

Может быть, это проблема с графиком?Тем не менее, я обнаружил, что если я запускаю model.predict() в главном потоке, ошибка также исчезает, но график все равно отличается при загрузке модели и при прогнозировании.

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import threading
import tensorflow as tf

model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

model.predict(np.array([[0,0,0,0,0,0,0,0]]))
print(threading.current_thread(), tf.get_default_graph())

def thread_fn():
    print(threading.current_thread(), tf.get_default_graph())
    model.predict(np.array([[0,0,0,0,0,0,0,0]]))

th = threading.Thread(target=thread_fn)
th.start()
th.join()

вывод:

<_MainThread(MainThread, started 4416488896)> <tensorflow.python.framework.ops.Graph object at 0x1040a5278>
<Thread(Thread-1, started 123145542557696)> <tensorflow.python.framework.ops.Graph object at 0xb36258cf8>

Так какова реальная причина, вызывающая эту ошибку?Какой из приведенных выше методов действительно решает эту проблему?Есть ли другое решение? Этот ответ говорит мне, что я должен использовать model._make_predict_function(), но я не вижу никакого эффекта от этой функции.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...