Керас и тензор потока (ошибка бэкэнда) Тензор conv2d_1_input: 0, указанный в feed_devices или fetch_devices, не найден в графике - PullRequest
0 голосов
/ 19 марта 2020

Я использовал керасы и тензорфоу, и я совершенно новичок в этом. Я обучил свои модели, и когда я делаю, чтобы предсказать это, ошибка показывает. Это код, который я использовал для предсказания изображения

import numpy as np
from flask import Flask, request, jsonify, render_template
import numpy
from PIL import Image
import os
import tensorflow.keras
from werkzeug.utils import secure_filename
from keras.models import load_model

app = Flask(__name__)

model = load_model('traffic_classifier.h5')
model._make_predict_function()

@app.route('/')
def index():
    # Main page
    return render_template('index.html')

@app.route('/traffic')
def traffic():
    # Main page
    return render_template('traffic.html')

@app.route('/sleep')
def sleep():
    # Main page
    return render_template('sleep.html')

@app.route('/predict',methods=['POST'])
def predict():
    '''
    For rendering results on HTML GUI
    '''



    classes = { 1:'Speed limit (20km/h)',
            2:'Speed limit (30km/h)',      
            3:'Speed limit (50km/h)',       
            4:'Speed limit (60km/h)',      
            5:'Speed limit (70km/h)',    
            6:'Speed limit (80km/h)',      
            7:'End of speed limit (80km/h)',     
            8:'Speed limit (100km/h)',    
            9:'Speed limit (120km/h)',     
           10:'No passing',   
           11:'No passing veh over 3.5 tons',     
           12:'Right-of-way at intersection',     
           13:'Priority road',    
           14:'Yield',     
           15:'Stop',       
           16:'No vehicles',       
           17:'Veh > 3.5 tons prohibited',       
           18:'No entry',       
           19:'General caution',     
           20:'Dangerous curve left',      
           21:'Dangerous curve right',   
           22:'Double curve',      
           23:'Bumpy road',     
           24:'Slippery road',       
           25:'Road narrows on the right',  
           26:'Road work',    
           27:'Traffic signals',      
           28:'Pedestrians',     
           29:'Children crossing',     
           30:'Bicycles crossing',       
           31:'Beware of ice/snow',
           32:'Wild animals crossing',      
           33:'End speed + passing limits',      
           34:'Turn right ahead',     
           35:'Turn left ahead',       
           36:'Ahead only',      
           37:'Go straight or right',      
           38:'Go straight or left',      
           39:'Keep right',     
           40:'Keep left',      
           41:'Roundabout mandatory',     
           42:'End of no passing',      
           43:'End no passing veh > 3.5 tons' }




    if request. method == "POST":
        #image=request. form["fileupload"]

        f = request.files['file']

        # Save the file to ./uploads
        basepath = os.path.dirname(__file__)
        file_path = os.path.join(
            basepath, 'uploads', secure_filename(f.filename))
        f.save(file_path)  


    image = Image.open(file_path)
    image = image.resize((30,30))
    image = numpy.expand_dims(image, axis=0)
    image = numpy.array(image)

    pred = model.predict_classes([image])[0]

    sign = classes[pred+1]





    return render_template('traffic.html', prediction_text='This sign represents {}'.format(sign))


if __name__ == "__main__":
    app.run(debug=True)

Я получаю ошибку

tenorflow. python .framework.errors_impl.InvalidArgumentError tenorflow. python .framework.errors_impl.InvalidArgumentError: Tensor conv2d_1_input: 0, указанный в либо feed_devices, либо fetch_devices не найдены в Графике

что с этим делать ??

Ответы [ 2 ]

0 голосов
/ 23 марта 2020

Проблема в том, что Flask использует потоки. Это означает, что для каждого запроса Flask создает новый поток. Таким образом, ваша модель не видна из запроса.

Чтобы решить эту проблему, необходимо сделать модель частью глобального сеанса, который используется повсеместно.

Решение можно найти здесь как ошибка .

from tensorflow.python.keras.backend import set_session
from tensorflow.python.keras.models import load_model

tf_config = some_custom_config
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()

# IMPORTANT: models have to be loaded AFTER SETTING THE SESSION for keras! 
# Otherwise, their weights will be unavailable in the threads after the session there has been set
set_session(sess)
model = load_model(...)

тогда внутри вашего метода:

def predict():
    ....
    global sess
    global graph
    with graph.as_default():
    set_session(sess)
    pred = model.predict_classes(...)
    ...
0 голосов
/ 19 марта 2020

Решил, добавив эти коды

config = tensorflow.ConfigProto(
    device_count={'GPU': 1},
    intra_op_parallelism_threads=1,
    allow_soft_placement=True
)

config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.6

session = tensorflow.Session(config=config)
keras.backend.set_session(session)

model = load_model('traffic_classifier.h5')
model._make_predict_function()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...