Модель Keras (.h5), преобразованная в модель Tflite, всегда равна 0 вероятности, за исключением одной метки, которая равна 1 в Android. Почему это происходит? - PullRequest
0 голосов
/ 26 марта 2020

Некоторые важные моменты: 1. Работа над собственным приложением реагирования 2. Модель FLOAT32 tflite 3. Передача изображения в градациях серого

Для каждого проходимого изображения я получаю что-то вроде этого: [{"label": "4", "prob": "1.0"}, {"label": "0", "prob": "1.258652E-8"}, {"label": "3", "prob": "2.6068769E-15"}, {"label": "2", "prob": "8.913677E-17"}, {"label": "1", "prob": "1.986853E-19"}]

Метка «4» всегда равна 1,0 или очень близко к 1,0 для каждого изображения.

Иногда выигрывает ярлык «2», но очень и очень редко. [{"label": "2", "prob": "0.9999604"}, {"label": "0", "prob": "3.142011E-5"}, {"label": "4", "prob": "7.964015E-6"}, {"label": "3", "prob": "1.0607963E-7"}, {"label": "1", "prob": "5.5956546E-9"}]

С другой моделью UINT8 tflite у меня была строка, которая выглядит следующим образом: float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;

Для модели FLOAT32 я ничего не делаю, так что, возможно, именно здесь я ' что-то пошло не так?

Вот некоторый соответствующий код, который делает тяжелый анализ:

public class KerasModelModule extends ReactContextBaseJavaModule {

  private static final int DIM_BATCH_SIZE = 1;
  private static final int DIM_PIXEL_SIZE = 1;
  static final int DIM_IMG_SIZE_X = 300;
  static final int DIM_IMG_SIZE_Y = 100;
  private static final int BYTE_SIZE_OF_FLOAT = 4;

  private final String TAG = this.getClass().getSimpleName();
  protected ByteBuffer imgData = ByteBuffer.allocateDirect(BYTE_SIZE_OF_FLOAT * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);

  private Interpreter tflite;
  private List<String> labelList;
  private float[][] labelProbArray = null;
  private static final int RESULTS_TO_SHOW = 5;
  private static final float IMAGE_MEAN = 0f;
  private static final float IMAGE_STD = 255f;

  private static final String MODEL_PATH = "best-model-03242020_model.tflite";
  private static final String LABEL_PATH = "best-model-03242020_dict.txt";
  private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];

  public KerasModelModule(ReactApplicationContext reactContext) {
    super(reactContext);
  }

  @Override
  public String getName() {
    return "KerasModel";
  }

  @ReactMethod 
  void predict(String base64Image, final Promise promise) {
    imgData.order(ByteOrder.nativeOrder());
    labelProbArray = new float[1][RESULTS_TO_SHOW];

    try {
      labelList = loadLabelList();
    } catch (Exception ex) {
      ex.printStackTrace();
    }

    byte[] decodedString = Base64.decode(base64Image, Base64.DEFAULT);
    Bitmap old_bitmap = BitmapFactory.decodeByteArray(decodedString, 0, decodedString.length);
    Bitmap bitmap = Bitmap.createScaledBitmap(old_bitmap, DIM_IMG_SIZE_X, DIM_IMG_SIZE_Y, true);
    convertBitmapToByteBuffer(bitmap);


    try {
      tflite = new Interpreter(loadModelFile());
    } catch (Exception ex) {
      ex.printStackTrace();
      Log.w("FIND_exception", "1");
    }

    tflite.run(imgData, labelProbArray);
    promise.resolve(getResult());

  }


  private WritableNativeArray getResult() {
    WritableNativeArray result = new WritableNativeArray();

    try {

      for (int i = 0; i < RESULTS_TO_SHOW; i++) {
          WritableNativeMap map = new WritableNativeMap();

          map.putString("label", labelList.get(i));
          // float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;
          map.putString("prob", String.valueOf(labelProbArray[0][i]));
          result.pushMap(map);

      }
    } catch (Exception ex) {
      ex.printStackTrace();
    }

    return result;
  }

  private List<String> loadLabelList() throws IOException {
    Activity activity = getCurrentActivity();
      List<String> labelList = new ArrayList<String>();
      BufferedReader reader =
              new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
      String line;
      while ((line = reader.readLine()) != null) {
          labelList.add(line);
      }
      reader.close();
      return labelList;
  }

  private void convertBitmapToByteBuffer(Bitmap bitmap) {
      if (imgData == null) {
          return;
      }
      imgData.rewind();

      bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
      // Convert the image to floating point.
      int pixel = 0;
      for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
          for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
              final int val = intValues[pixel++];
              float rChannel = (val >> 16) & 0xFF;
              float gChannel = (val >> 8) & 0xFF;
              float bChannel = (val) & 0xFF;
              float gray = rChannel * 0.299f + gChannel * 0.587f + bChannel * 0.114f;
              imgData.putFloat(gray);
          }
      }
  }

  private MappedByteBuffer loadModelFile() throws IOException {
    Activity activity = getCurrentActivity();
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }
}
...