Держите модель инициализированной в колбе - PullRequest
1 голос
/ 24 сентября 2019

У меня есть проблема, которую я чувствую, легко исправить, но я понятия не имею, я довольно новичок в колбе.

Я хочу иметь колбу, которая позволяет пользователю загружать изображение и тестировать его на специально обученной модели керас для обнаружения пород кошек.

Код уже работает для предварительного прогнозирования без использования колбы, и я могу заставить его работать, если модель инициализируется непосредственно перед запуском предикации.

Моя цель заключается в следующем: как заставить модель оставаться инициализированной, чтобы ее можно было запускать любое количество раз без необходимости повторной инициализации каждый раз?

Вот следующеекод:

import os
#import magic
import urllib.request
from app import app
from flask import Flask, flash, request, redirect, render_template
from werkzeug.utils import secure_filename
import tensorflow as tf
import numpy as np
import requests
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.models import load_model
from tensorflow.keras.applications import xception
from PIL import Image

ALLOWED_EXTENSIONS = set(['txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'])

class CatClass:
    def __init__(self):
        from tensorflow.python.lib.io import file_io
        model_file = file_io.FileIO('gs://modelstorageforcats/94Kitty.h5', mode='rb')

        temp_model_location = './94Kitty.h5'
        temp_model_file = open(temp_model_location, 'wb')
        try:
            temp_model_file.write(model_file.read())
            temp_model_file.close()
            model_file.close()
        except:
            raise("Issues with getting the model")
        # get model
        self.catModel = load_model(temp_model_location)

    def catEstimator(self, catImage):
        script_dir = os.path.dirname(__file__) #<-- absolute dir the script is in
        rel_path = "uploads/" + catImage
        abs_file_pathk = os.path.join(script_dir, rel_path)

        #get picture the proper size for xception
        try:
            kittyPic = image.load_img(abs_file_pathk, target_size=(299,299))
            x = xception.preprocess_input(np.expand_dims(kittyPic.copy(), axis=0))
        except:
            raise("Error with the images")

        #cat names the way the model learned it
        catNames = ["Bengal","Abyssinian","BritishShorthair","Birman","Sphynx","Bombay","EgyptianMau","Persian","Ragdoll","MaineCoon","Siamese","RussianBlue","AmericanBobtail","DevonRex","AmericanCurl","DonSphynx","Manx","Balinese","Burmilla","Burmese","KhaoManee","Chausie","AmericanShortHair","Chartreux","Pixiebob","JapaneseBobtail","BritishLonghair","CornishRex","Tabby","Somali","ExoticShortHair","Tonkinese","OrientalShortHair","Minskin","Korat","Savannah","Havana","Singapura","Nebelung","OrientalLonghair","TurkishAngora","ScottishFold","KurilianBobtail","Lykoi","ScottishFoldLonghair","Ocicat","Munchkin","SelkirkRex","AustralianMist","AmericanWireHair","TurkishVan","SnowShoe","Peterbald","Siberian","Toybob","Himalayan","LePerm","NorwegianForestCat"]

        prediction = (self.catModel.predict(x))

        label = int(np.argmax(prediction, axis=-1))

        return(catNames[label])

catter = CatClass()


def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/')
def upload_form():
    return render_template('upload.html')

@app.route('/', methods=['POST'])
def upload_file():
    if request.method == 'POST':
        # check if the post request has the file part
        if 'file' not in request.files:
            flash('No file part')
            return redirect(request.url)
        file = request.files['file']
        if file.filename == '':
            flash('No file selected for uploading')
            return redirect(request.url)
        if file and allowed_file(file.filename):
            filename = secure_filename(file.filename)
            file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
            flash('File successfully uploaded')
            flash(catter.catEstimator(filename))
            return redirect('/')
        else:
            flash('Allowed file types are txt, pdf, png, jpg, jpeg, gif')
            return redirect(request.url)

if __name__ == "__main__":
    app.run()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...