Некоторые важные моменты: 1. Работа над собственным приложением реагирования 2. Модель FLOAT32 tflite 3. Передача изображения в градациях серого
Для каждого проходимого изображения я получаю что-то вроде этого: [{"label": "4", "prob": "1.0"}, {"label": "0", "prob": "1.258652E-8"}, {"label": "3", "prob": "2.6068769E-15"}, {"label": "2", "prob": "8.913677E-17"}, {"label": "1", "prob": "1.986853E-19"}]
Метка «4» всегда равна 1,0 или очень близко к 1,0 для каждого изображения.
Иногда выигрывает ярлык «2», но очень и очень редко. [{"label": "2", "prob": "0.9999604"}, {"label": "0", "prob": "3.142011E-5"}, {"label": "4", "prob": "7.964015E-6"}, {"label": "3", "prob": "1.0607963E-7"}, {"label": "1", "prob": "5.5956546E-9"}]
С другой моделью UINT8 tflite у меня была строка, которая выглядит следующим образом: float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;
Для модели FLOAT32 я ничего не делаю, так что, возможно, именно здесь я ' что-то пошло не так?
Вот некоторый соответствующий код, который делает тяжелый анализ:
public class KerasModelModule extends ReactContextBaseJavaModule {
private static final int DIM_BATCH_SIZE = 1;
private static final int DIM_PIXEL_SIZE = 1;
static final int DIM_IMG_SIZE_X = 300;
static final int DIM_IMG_SIZE_Y = 100;
private static final int BYTE_SIZE_OF_FLOAT = 4;
private final String TAG = this.getClass().getSimpleName();
protected ByteBuffer imgData = ByteBuffer.allocateDirect(BYTE_SIZE_OF_FLOAT * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
private Interpreter tflite;
private List<String> labelList;
private float[][] labelProbArray = null;
private static final int RESULTS_TO_SHOW = 5;
private static final float IMAGE_MEAN = 0f;
private static final float IMAGE_STD = 255f;
private static final String MODEL_PATH = "best-model-03242020_model.tflite";
private static final String LABEL_PATH = "best-model-03242020_dict.txt";
private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
public KerasModelModule(ReactApplicationContext reactContext) {
super(reactContext);
}
@Override
public String getName() {
return "KerasModel";
}
@ReactMethod
void predict(String base64Image, final Promise promise) {
imgData.order(ByteOrder.nativeOrder());
labelProbArray = new float[1][RESULTS_TO_SHOW];
try {
labelList = loadLabelList();
} catch (Exception ex) {
ex.printStackTrace();
}
byte[] decodedString = Base64.decode(base64Image, Base64.DEFAULT);
Bitmap old_bitmap = BitmapFactory.decodeByteArray(decodedString, 0, decodedString.length);
Bitmap bitmap = Bitmap.createScaledBitmap(old_bitmap, DIM_IMG_SIZE_X, DIM_IMG_SIZE_Y, true);
convertBitmapToByteBuffer(bitmap);
try {
tflite = new Interpreter(loadModelFile());
} catch (Exception ex) {
ex.printStackTrace();
Log.w("FIND_exception", "1");
}
tflite.run(imgData, labelProbArray);
promise.resolve(getResult());
}
private WritableNativeArray getResult() {
WritableNativeArray result = new WritableNativeArray();
try {
for (int i = 0; i < RESULTS_TO_SHOW; i++) {
WritableNativeMap map = new WritableNativeMap();
map.putString("label", labelList.get(i));
// float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;
map.putString("prob", String.valueOf(labelProbArray[0][i]));
result.pushMap(map);
}
} catch (Exception ex) {
ex.printStackTrace();
}
return result;
}
private List<String> loadLabelList() throws IOException {
Activity activity = getCurrentActivity();
List<String> labelList = new ArrayList<String>();
BufferedReader reader =
new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
String line;
while ((line = reader.readLine()) != null) {
labelList.add(line);
}
reader.close();
return labelList;
}
private void convertBitmapToByteBuffer(Bitmap bitmap) {
if (imgData == null) {
return;
}
imgData.rewind();
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
// Convert the image to floating point.
int pixel = 0;
for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
final int val = intValues[pixel++];
float rChannel = (val >> 16) & 0xFF;
float gChannel = (val >> 8) & 0xFF;
float bChannel = (val) & 0xFF;
float gray = rChannel * 0.299f + gChannel * 0.587f + bChannel * 0.114f;
imgData.putFloat(gray);
}
}
}
private MappedByteBuffer loadModelFile() throws IOException {
Activity activity = getCurrentActivity();
AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}