Итак, я пытаюсь создать модель классификации собак и кошек с использованием Keras. Частью моей цели является создание веб-сайта, который развертывает модель с использованием Tensorflow.js. Я успешно развернул модель, используя Flask в качестве сервера.
Основная проблема заключается в том, что модель Tensorflow.js работает намного хуже, чем модель в простых кератах. При использовании обычных керас моя модель достигла около 90% точности данных испытаний. Однако при использовании в tenorflow.js модель не получила ни одного из тестовых изображений правильно. Буду признателен за любую помощь или советы по устранению этой проблемы.
templates / index.html
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width">
<title>repl.it</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
<link href="{{ url_for('static', filename='index.css') }}" rel="stylesheet" type="text/css" />
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
</head>
<body onload="$('#result').hide();$('#continue').hide();">
<div class="container-fluid">
<!-- START HEADER -->
<div class="row" id="headerRow">
<div class="col-md d-flex align-items-center" id="headerColumn">
<h2>Cat<span class='or'>or</span>Dog</h2>
</div>
</div>
<!-- END HEADER -->
<!-- START BODY -->
<div class="row bodyRow" id='bodyRow'>
<div class="col-md d-flex align-items-center bodyColumn">
<div class="body">
<form class="d-flex align-items-center justify-content-center imageSubmitForm" method="POST" enctype="multipart/form-data">
<label class="d-flex align-items-center justify-content-center" for='imageInputField'>
<i class="material-icons">file_upload</i>
<p id='result'></p>
<br/>
<p id='continue'>Press Anywhere to continue...</p>
</label>
<input class="imageInputField" id='imageInputField' type='file' onchange='getPrediction(url)'/>
</form>
</div>
</div>
</div>
<!-- END BODY -->
<!-- START RESULT -->
<div class="row resultRow">
<div class="col-md-6 classResultColumn">
<div class="d-flex align-items-center justify-content-center classResultBox">
<p id='classResult'></p>
</div>
</div>
<div class="col-md-6 scoreResultColumn">
<div class="d-flex align-items-center justify-content-center scoreResultBox">
<p id='scoreResult'></p>
</div>
</div>
</div>
<!-- END RESULT -->
<!-- START FOOTER -->
<!--
<div class="row d-flex align-items-center footerRow" id='footerRow'>
<center><a src="#">Source Code</a></center>
</div>
-->
<!-- FOOTER -->
</div>
<!-- START SCRIPTS -->
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js"></script>
<script src="https://code.jquery.com/jquery-3.4.1.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>
<script src="{{ url_for('static', filename='index.js') }}"></script>
<!-- END SCRIPTS -->
</body>
</html>
static / index.js
let fileInput = document.getElementById("imageInputField");
let classResultElement = document.getElementById("classResult");
let scoreResultElement = document.getElementById("scoreResult");
let url = "/model";
let model;
let file;
let data;
let responseContent;
let features;
let predictedClass;
let getPrediction = async(url) => {
if (!model)
model = await tf.loadLayersModel(url);
file = fileInput.files[0];
data = new FormData();
data.append("file", file);
$.ajax({
url : "/api/preprocess",
type: 'POST',
data: data,
traditional: true,
processData: false,
contentType: false,
success: function(response)
{
responseContent = JSON.parse(response)['image'];
if (responseContent != "False")
{
features = tf.tensor(responseContent);
score = model.predict(features).dataSync();
alert(score);
if (score >= 0.5) {
predictedClass = "Dog";
classResultElement.innerHTML = "<b>Predicted Class:</b> " + predictedClass;
scoreResultElement.innerHTML = "<b>Certainty:</b> " + score*100.0 + "%";
} else {
predictedClass = "Cat";
classResultElement.innerHTML = "<b>Predicted Class:</b>" + predictedClass;
scoreResultElement.innerHTML = "<b>Certainty:</b> " + (1.0 - score) * 100.0 + "%";
}
alert(predictedClass);
}
}
});
}
app.py
import flask
from flask_cors import CORS
from werkzeug import secure_filename
import time
import os
import keras
import numpy as np
import json
import matplotlib.pyplot as plt
app = flask.Flask(__name__)
CORS(app)
UPLOADS_DIR = "uploads/"
@app.route("/")
def index():
"""
Fetch and return the main homepage.
"""
return flask.render_template("index.html")
@app.route("/favicon.ico")
def get_favicon():
"""
Return a fake message in order to silence the error caused by a favicon not being found.
"""
return "Favicon Does Not Exist"
@app.route("/model")
def get_modeljson():
"""
Get the model.json file and return it's contents.
"""
with open("model/model.json", "r") as f:
return f.read()
@app.route("/<path:path>")
def get_shard(path):
"""
get the binary weight file for the model (also known as a shard).
path => the filename of the binary weight file.
"""
return flask.send_from_directory("model/", path)
@app.route("/api/preprocess", methods=['POST'])
def preprocess():
"""
takes an image object from an AJAX request and returns a normalized list of the values.
"""
if flask.request.method == 'POST':
file = flask.request.files['file']
filename = secure_filename(file.filename)
new_filename = "{}_{}".format(time.time(), filename)
file.save(os.path.join(UPLOADS_DIR, new_filename))
img_obj = keras.preprocessing.image.load_img(os.path.join(UPLOADS_DIR, new_filename), target_size=(224, 224))
img_arr = keras.preprocessing.image.img_to_array(img_obj).reshape(1, 224, 224, 3)
img_arr = np.divide(img_arr, 255.)
os.remove(os.path.join(UPLOADS_DIR, new_filename))
return json.dumps({"image":img_arr.tolist()})
return json.dumps({"image":"False"})
if __name__ == "__main__":
app.run()
Здесь вы можете найти URL-адрес записной книжки, используемой для обучения модели , здесь . Вы можете найти блокнот, использованный для проверки кода здесь .
Любая помощь или советы с благодарностью.