Повторное использование модели Tensorflow из другого класса - PullRequest
0 голосов
/ 28 апреля 2020

Название может быть непонятным, но вот сценарий:

  • У меня есть предварительно обученная модель,
  • Я собираю все, чтобы выступать в качестве службы со слушателем это получит CloudEvents (для упрощения: сообщения), которые дают имя анализируемого изображения,
  • Поскольку модель довольно тяжелая (> 1 ГБ), я хочу загрузить ее при инициализации и делать прогнозы по мере поступления событий

Вот упрощенный код:

import http.server
import io
import json
import logging
import socketserver
import sys

import cv2
import numpy as np
import tensorflow as tf

from cloudevents.sdk import marshaller
from cloudevents.sdk.event import v02

# Set logging
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)


mapping = {'normal': 0, 'pneumonia': 1, 'COVID-19': 2}
inv_mapping = {0: 'normal', 1: 'pneumonia', 2: 'COVID-19'}

meta_url = './COVIDNet-CXR-Large/model.meta'
ckpt_url = './COVIDNet-CXR-Large/model-8485'


# Events listener
m = marshaller.NewDefaultHTTPMarshaller()

class ForkedHTTPServer(socketserver.ForkingMixIn, http.server.HTTPServer):
    """Handle requests with fork."""


class CloudeventsServer(object):
    """Listen for incoming HTTP cloudevents requests.
    cloudevents request is simply a HTTP Post request following a well-defined
    of how to pass the event data.
    """
    def __init__(self, port=8888):
        self.port = port

    def start_receiver(self, func):
        """Start listening to HTTP requests
        :param func: the callback to call upon a cloudevents request
        :type func: cloudevent -> none
        """
        class BaseHttp(http.server.BaseHTTPRequestHandler):
            def do_POST(self):
                logging.info('POST received')
                content_type = self.headers.get('Content-Type')
                content_len = int(self.headers.get('Content-Length'))
                headers = dict(self.headers)
                data = self.rfile.read(content_len)
                data = data.decode('utf-8')
                logging.info(content_type)
                logging.info(data)

                if content_type != 'application/json':
                    logging.info('Not JSON')
                    data = io.StringIO(data)

                try:
                    event = v02.Event()
                    event = m.FromRequest(event, headers, data, json.loads)
                except Exception as e:
                    logging.error(f"Event error: {e}")
                    raise   

                logging.info(event)

                # Send ack first to free connection
                self.send_response(204)
                self.end_headers()

                func(event)

                return

        socketserver.TCPServer.allow_reuse_address = True
        with ForkedHTTPServer(("", self.port), BaseHttp) as httpd:
            try:
                logging.info("serving at port {}".format(self.port))
                httpd.serve_forever()
            except:
                httpd.server_close()
                raise

class Model(object):

    def __init__(self,meta_url,ckpt_url):
        self.sess = tf.Session()
        tf.get_default_graph()

        saver = tf.train.import_meta_graph(meta_url)
        saver.restore(self.sess,ckpt_url)

        graph = tf.compat.v1.get_default_graph()

        self.image_tensor = graph.get_tensor_by_name("input_1:0")
        self.pred_tensor = graph.get_tensor_by_name("dense_3/Softmax:0")


    def prediction(self,key):
        logging.info('start prediction')

        x = cv2.imread(key)

        h, w, c = x.shape
        x = x[int(h/6):, :]
        x = cv2.resize(x, (224, 224))
        x = x.astype('float32') / 255.0

        # Make prediction
        logging.info('make prediction')
        pred = self.sess.run(self.pred_tensor, feed_dict={self.image_tensor: np.expand_dims(x, axis=0)})
        logging.info('prediction made')
        # Format data
        data = {'prediction':inv_mapping[pred.argmax(axis=1)[0]],'confidence':'Normal: {:.3f}, Pneumonia: {:.3f}, COVID-19: {:.3f}'.format(pred[0][0], pred[0][1], pred[0][2])}
        logging.info(data)

        return data

# Extract data from incoming event
def extract_data(msg):
    return msg['data']

# Run this when a new event has been received
def run_event(event):
    logging.info(event.Data())

    # Retrieve info from notification
    extracted_data = extract_data(event.Data())
    uid = extracted_data['uid']
    img_key = extracted_data['image_name']
    logging.info('Analyzing: ' + img_key + ' for uid: ' + uid)


    # Make prediction
    result = model.prediction(img_key)

    logging.info('result=' + result['prediction'])

# Load model
model = Model(meta_url,ckpt_url)
logging.info('model initialized')

# Start event listener
client = CloudeventsServer()
client.start_receiver(run_event)
#result = model.prediction('1-s2.0-S0140673620303706-fx1_lrg.jpg')

Что работает:

  • При запуске модель загружается,
  • Слушатель / анализатор событий работает, я могу получить входящие значения,
  • Если я делаю прогноз (последняя строка, закомментированный), не запуская слушатель (или не делаю его до запуска слушателя), он работает. Доказательство того, что модель хорошо загружена (да!).

Что не работает:

  • При попытке вызвать экземпляр модели из функции, вызванной слушателем ( который работает, когда он добирается до функции прогнозирования), я получаю:
terminate called after throwing an instance of 'std::system_error'
  what():  No such process

Так что кажется, что экземпляр Model есть (потому что функция прогнозирования запущена), но я вроде как теряю сеанс tf в слушателе-> run_event-> конвейер прогноза ... По крайней мере, это то, что я понимаю.

Чего мне не хватает? Есть ли лучший способ сделать это?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...