Я тренирую модель, используя Google AutoML. Затем я экспортировал модель TFLite для запуска в приложении android.
С одним и тем же изображением, используемым в приложении и в браузере (Google AutoML дает вам возможность протестировать вашу модель в браузере), я получаю очень точный ответ в браузере, но в своем приложении я получаю не очень точный ответ.
Есть мысли о том, почему это может происходить?
Соответствующий код:
public void predict(String imagePath, final Promise promise) {
imgData.order(ByteOrder.nativeOrder());
labelProbArray = new byte[1][RESULTS_TO_SHOW];
try {
labelList = loadLabelList();
} catch (Exception ex) {
ex.printStackTrace();
}
Log.w("FIND_ ", imagePath);
Bitmap bitmap = BitmapFactory.decodeFile(imagePath);
convertBitmapToByteBuffer(bitmap);
try {
tflite = new Interpreter(loadModelFile());
} catch (Exception ex) {
Log.w("FIND_exception in loading tflite", "1");
}
tflite.run(imgData, labelProbArray);
promise.resolve(getResult());
}
private WritableNativeArray getResult() {
WritableNativeArray result = new WritableNativeArray();
//ArrayList<JSONObject> result = new ArrayList<JSONObject>();
try {
for (int i = 0; i < RESULTS_TO_SHOW; i++) {
WritableNativeMap map = new WritableNativeMap();
map.putString("label", Integer.toString(i));
float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;
map.putString("prob", String.valueOf(output));
result.pushMap(map);
Log.w("FIND_label ", Integer.toString(i));
Log.w("FIND_prob ", String.valueOf(labelProbArray[0][i]));
}
} catch (Exception ex) {
ex.printStackTrace();
}
return result;
}
private List<String> loadLabelList() throws IOException {
Activity activity = getCurrentActivity();
List<String> labelList = new ArrayList<String>();
BufferedReader reader =
new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
String line;
while ((line = reader.readLine()) != null) {
labelList.add(line);
}
reader.close();
return labelList;
}
private void convertBitmapToByteBuffer(Bitmap bitmap) {
if (imgData == null) {
return;
}
imgData.rewind();
Log.w("FIND_bitmap width ", String.valueOf(bitmap.getWidth()));
Log.w("FIND_bitmap height ", String.valueOf(bitmap.getHeight()));
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
// Convert the image to floating point.
int pixel = 0;
//long startTime = SystemClock.uptimeMillis();
for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
final int val = intValues[pixel++];
imgData.put((byte) ( ((val >> 16) & 0xFF)));
imgData.put((byte) ( ((val >> 8) & 0xFF)));
imgData.put((byte) ( (val & 0xFF)));
Log.w("FIND_i", String.valueOf(i));
Log.w("FIND_j", String.valueOf(j));
}
}
}
private MappedByteBuffer loadModelFile() throws IOException {
Activity activity = getCurrentActivity();
AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
Одна из вещей, которые я не знаю, если это правильно, это как я корректирую свои выходные данные в вероятности. В методе getResult()
я пишу это:
float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;
Я делю на 255, чтобы получить число от 0 до 1, но я не знаю, если это правильный.
Модель .tflite представляет собой квантованную модель uint8, если это помогает.