Итак, у меня есть этот проект, в котором я должен:
- обучить нейронную сеть. Я нашел этот проект для распознавания рукописного текста (ссылка на репозиторий github: HTR Network ), который имеет три разные архитектуры. Я выбрал «puigcerver»
- , чтобы преобразовать его в модель tflite
- , загрузить в приложение Android и получить вывод
Первые два пункта прошли хорошо, но последний застал меня врасплох. Я могу получить вывод (3D-тензор - форма: [1] [128] [98]), но я не знаю, как его декодировать.
У меня есть две основные проблемы:
- Вывод модели tflite - это 3D-тензор с плавающей запятой, где каждый массив [N] 1D из 98 значений должен представлять вероятность каждого из символов в кодировка для N символов из 128 символов предложения. В этой статье автор заявляет, что кодировка состоит из 95 символов: Article . Итак, первая проблема заключается в том, что у меня есть 3 значения (для каждого символа предложения), которые я не ожидал получить
- Все значения этого трехмерного тензора очень малы (т. Е. 2,15..E-24 и меньше), за исключением последнего значения (98-го значения), которое составляет около 0,98 / 0,99, и некоторых других значений, которые составляют около 0,002 / 0,004 / 0,008. Если я обращаюсь к ним как к вероятностям, ища более высокое значение (исключая 96-е, 97-е, 98-е значение), я получаю предложения типа «LLLLLqqqqqgggggggggoo ... pppp ...» (явно неверно)
I Я попытался вывести одно и то же изображение с исходной сетью (опция --image), и результат был в порядке, поэтому я подумал, что, возможно, я делаю некоторые ошибки при загрузке модели tflite или изображения. Я также подумал, что, возможно, мне нужно выполнить лучевой (или жадный) поиск вместо того, чтобы просто искать более высокое значение.
Итак, мой вопрос, действительно ли вывод представляет собой тензор с вероятностями или я что-то упустил? Как правильно декодировать вывод?
Преобразование TFlite:
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow.keras.models as models
import tensorflow.keras.layers as layers
from tensorflow.keras.models import load_model
from network.model import HTRModel
keras.backend.clear_session()
model = load_model("checkpoint_weights.hdf5",custom_objects={'ctc_loss_lambda_func':HTRModel.ctc_loss_lambda_func}, compile=False)
model.compile(loss=HTRModel.ctc_loss_lambda_func)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.__setattr__('allow_custom_ops',True)
tflite = converter.convert()
open("tflite_model.tflite","wb").write(tflite)
Загрузка модели и изображения (в приложении Android - Kotlin)
...
@Throws(IOException::class)
fun initializeInterpreter() {
// Load the TF Lite model
val assetManager = context.assets
println("Loading model file...")
val model = loadModelFile(assetManager)
println("Inizializing TF Lite interpreter...")
// Initialize TF Lite Interpreter (with NNAPI enabled)
val options = Interpreter.Options()
//options.setUseNNAPI(true)
val interpreter = Interpreter(model, options)
// Read input shape from model file
println("Reading input shape from model...")
val inputShape = interpreter.getInputTensor(0).shape()
inputImageWidth = inputShape[1]
println("Shape 1: " + inputShape[1])
inputImageHeight = inputShape[2]
println("Shape 2: " + inputShape[2])
modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
println("Total input size: $modelInputSize")
// Finish interpreter initialization
this.interpreter = interpreter
isInitialized = true
Log.d(TAG, "Initialized TFLite interpreter.")
}
@Throws(IOException::class)
private fun loadModelFile(assetManager: AssetManager): ByteBuffer {
val fileDescriptor = assetManager.openFd(MODEL_FILE)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
fun classify(bitmap: Bitmap): String{
if (!isInitialized) {
throw IllegalStateException("TF Lite Interpreter is not initialized yet.")
}
/*var startTime: Long
var elapsedTime: Long*/
// Preprocessing: resize the input
println("Preprocessing - Resizing the input..")
//startTime = System.nanoTime()
val resizedImage = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true)
val byteBuffer = convertBitmapToByteBuffer(resizedImage)
//elapsedTime = (System.nanoTime() - startTime) / 1000000
//Log.d(TAG, "Preprocessing time = " + elapsedTime + "ms")
//startTime = System.nanoTime()
println("Running interpreter...")
val result = Array(1) { Array(128) { FloatArray(98)} }
interpreter?.run(byteBuffer, result)
//elapsedTime = (System.nanoTime() - startTime) / 1000000
//Log.d(TAG, "Inference time = " + elapsedTime + "ms")*/
println("Result: $result")
predicted = result
val r = result[0]
var i : Int
var out = "["
for (i in 0..127)
{
out += getOutputString(r[i])
}
out += "]"
//return out
return decodeSentece()
}
fun classifyAsync(bitmap: Bitmap): Task<String> {
return call(executorService, Callable<String> { classify(bitmap) })
}
fun close() {
call(
executorService,
Callable<String> {
interpreter?.close()
Log.d(TAG, "Closed TFLite interpreter.")
null
}
)
}
private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
val byteBuffer = ByteBuffer.allocateDirect(modelInputSize)
byteBuffer.order(ByteOrder.nativeOrder())
val pixels = IntArray(inputImageWidth * inputImageHeight)
bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
for (pixelValue in pixels) {
val r = (pixelValue shr 16 and 0xFF)
val g = (pixelValue shr 8 and 0xFF)
val b = (pixelValue and 0xFF)
// Convert RGB to grayscale and normalize pixel value to [0..1]
val normalizedPixelValue = (r + g + b) / 3.0f / 255.0f
byteBuffer.putFloat(normalizedPixelValue)
}
return byteBuffer
}
Это (просто для справки) первые три строки выходного тензора, который я получаю:
[[ 1.02767E-9 , 1.0279699E-9 , 3.139572E-4 , 2.1537318E-4 , 0.002740169 , 3.247818E-5 , 1.0521399E-5 , 2.9409595E-4 , 3.4732078E-4 , 1.8211146E-6 , 1.8896988E-5 , 1.0662472E-5 , 0.30825683 , 1.9493811E-5 , 0.15347835 , 7.209303E-5 , 0.10027219 , 1.7075523E-5 , 1.2530926E-5 , 0.011397267 , 0.004646272 , 2.7380593E-6 , 1.21792706E-4 , 0.0030580924 , 0.0037772718 , 0.01746696 , 7.144082E-4 , 2.8029652E-5 , 5.0955594E-5 , 0.030346025 , 6.488925E-4 , 4.2334714E-4 , 0.0010043866 , 2.5545945E-4 , 0.0057575954 , 9.177484E-4 , 8.877261E-6 , 8.9535155E-5 , 1.8161612E-4 , 9.084007E-5 , 0.16994531 , 5.3004638E-5 , 0.0055839983 , 2.0921101E-4 , 5.918604E-4 , 1.8135815E-4 , 1.8600904E-4 , 1.3010694E-5 , 7.701663E-5 , 0.011575605 , 0.001689526 , 7.148706E-5 , 4.6759746E-5 , 4.6543145E-4 , 8.721436E-7 , 0.007293349 , 1.371512E-4 , 7.3025643E-4 , 1.2985134E-4 , 2.2911412E-5 , 5.126359E-4 , 6.7653934E-7 , 8.5107595E-6 , 2.8127617E-6 , 1.8223985E-6 , 0.0017414617 , 5.9720107E-5 , 1.0222625E-9 , 1.095164E-9 , 4.0091674E-8 , 6.4829896E-5 , 0.0028546597 , 9.027818E-7 , 1.0092357E-9 , 1.0142923E-8 , 2.4069648E-5 , 0.020749098 , 3.682649E-4 , 2.9179654E-7 , 2.4412808E-5 , 3.0341505E-6 , 1.0448614E-9 , 9.900646E-10 , 9.750201E-10 , 3.0151805E-5 , 1.0438933E-9 , 1.0359494E-9 , 1.0682695E-9 , 1.0451963E-9 , 1.0827725E-9 , 1.0221299E-9 , 1.019132E-9 , 9.502418E-10 , 9.924109E-10 , 1.0422022E-9 , 1.0894825E-9 , 0.019479617 , 0.107967034]
[ 4.1440778E-15 , 4.1381066E-15 , 2.8557254E-11 , 1.8157608E-14 , 8.410285E-15 , 2.9407989E-16 , 1.060987E-16 , 2.4269218E-14 , 7.374521E-16 , 1.5804724E-18 , 5.3503064E-16 , 3.1957313E-14 , 3.8080086E-6 , 3.4774015E-13 , 9.287003E-7 , 6.938215E-10 , 1.06444844E-4 , 1.3310229E-11 , 1.4083793E-11 , 6.8555856E-8 , 8.523E-8 , 1.7820193E-15 , 8.969676E-11 , 1.0272463E-8 , 3.4963513E-8 , 7.1322137E-7 , 1.10203715E-7 , 3.426864E-11 , 2.6655234E-14 , 2.8331244E-5 , 6.8602155E-8 , 2.5318696E-9 , 2.9697837E-8 , 2.6458968E-9 , 3.169576E-9 , 9.308899E-10 , 1.7746693E-11 , 2.1560531E-16 , 3.1369616E-12 , 2.1284496E-15 , 8.978029E-11 , 3.51387E-14 , 1.712866E-11 , 3.6545718E-14 , 2.1302052E-14 , 1.5469606E-12 , 1.8957156E-12 , 5.3142797E-17 , 3.2412155E-16 , 3.797056E-12 , 8.453629E-11 , 5.663211E-12 , 1.2467538E-14 , 1.8673284E-13 , 1.1108751E-15 , 1.6332011E-10 , 1.0867543E-13 , 1.4709664E-11 , 1.1832392E-13 , 3.5229395E-16 , 2.2621423E-12 , 1.4893199E-14 , 1.1778153E-16 , 4.108518E-16 , 2.5608845E-16 , 2.2514338E-11 , 6.590058E-16 , 4.1018955E-15 , 4.210221E-15 , 4.0438975E-15 , 1.7089807E-13 , 1.7934002E-16 , 3.64835E-19 , 4.5916805E-15 , 7.155464E-15 , 2.0624084E-12 , 5.2085922E-9 , 9.621887E-10 , 2.6019372E-16 , 1.3379034E-14 , 5.9298675E-16 , 4.35525E-15 , 4.027287E-15 , 4.065149E-15 , 3.1142788E-15 , 4.126221E-15 , 4.1317653E-15 , 4.5336524E-15 , 4.308452E-15 , 4.2372895E-15 , 3.9838916E-15 , 4.3225607E-15 , 3.967846E-15 , 4.1830697E-15 , 4.086403E-15 , 4.2691384E-15 , 0.003423732 , 0.99643564]
[ 1.1704519E-17 , 1.1569569E-17 , 5.2640887E-15 , 2.9042624E-19 , 4.888624E-22 , 8.800831E-24 , 2.2810482E-22 , 1.0456246E-20 , 4.613259E-23 , 2.9788603E-24 , 4.2340931E-22 , 5.365159E-19 , 1.9403133E-8 , 3.617449E-18 , 8.07217E-9 , 3.6707356E-12 , 6.9733637E-6 , 1.00665275E-14 , 1.2904081E-13 , 2.300352E-10 , 3.3837338E-10 , 2.1202834E-21 , 4.4093415E-14 , 2.4844083E-11 , 1.8735478E-9 , 2.7742542E-8 , 1.5738684E-9 , 2.9336394E-13 , 9.488479E-19 , 2.5946217E-6 , 2.9507443E-9 , 1.29704225E-11 , 1.9227062E-9 , 1.7503684E-11 , 4.312201E-12 , 1.6009424E-12 , 2.0033327E-13 , 1.4647874E-23 , 4.3650106E-17 , 2.8878068E-23 , 1.3077074E-17 , 1.0820965E-21 , 1.6695615E-17 , 3.3014689E-21 , 8.4159397E-22 , 5.0169094E-19 , 1.9141252E-17 , 5.20164E-26 , 2.0089118E-24 , 2.2912117E-19 , 1.5849434E-15 , 4.2030135E-17 , 2.2811363E-21 , 2.2500594E-20 , 7.502992E-21 , 7.9013944E-16 , 5.3322657E-20 , 4.289789E-16 , 2.5743777E-20 , 1.5604532E-24 , 1.4845246E-18 , 2.1423116E-18 , 1.1037277E-24 , 1.1757478E-21 , 1.5670599E-20 , 7.630421E-15 , 9.0307036E-23 , 1.2181271E-17 , 1.18615665E-17 , 8.01336E-19 , 4.587686E-17 , 2.1303416E-24 , 4.3491843E-26 , 1.3229859E-17 , 1.1444598E-17 , 2.6614466E-13 , 3.7416266E-9 , 2.5167621E-9 , 3.0683186E-20 , 5.911551E-17 , 1.07436035E-20 , 1.1928041E-17 , 1.2051209E-17 , 1.1595416E-17 , 1.8123445E-20 , 1.1205827E-17 , 1.1453464E-17 , 1.3506912E-17 , 1.2615066E-17 , 1.1414165E-17 , 1.01463546E-17 , 1.2532566E-17 , 1.1185157E-17 , 1.1822221E-17 , 1.1537131E-17 , 1.184172E-17 , 0.1587565 , 0.84123385 ]
...
]
Извините за этот глупый вопрос, но я новичок в Tensorflow и Tensorflow Lite ... Заранее спасибо