Почему результаты в браузере, использующем Google AutoML, так сильно отличаются от экспортированного файла tflite в моем приложении android? - PullRequest
0 голосов
/ 23 февраля 2020

Я тренирую модель, используя Google AutoML. Затем я экспортировал модель TFLite для запуска в приложении android.

С одним и тем же изображением, используемым в приложении и в браузере (Google AutoML дает вам возможность протестировать вашу модель в браузере), я получаю очень точный ответ в браузере, но в своем приложении я получаю не очень точный ответ.

Есть мысли о том, почему это может происходить?

Соответствующий код:

  public void predict(String imagePath, final Promise promise) {
    imgData.order(ByteOrder.nativeOrder());
    labelProbArray = new byte[1][RESULTS_TO_SHOW];

    try {

      labelList = loadLabelList();
    } catch (Exception ex) {
      ex.printStackTrace();
    }
    Log.w("FIND_ ", imagePath);

    Bitmap bitmap = BitmapFactory.decodeFile(imagePath);
    convertBitmapToByteBuffer(bitmap);


    try {
      tflite = new Interpreter(loadModelFile());
    } catch (Exception ex) {
      Log.w("FIND_exception in loading tflite", "1");      
    }

    tflite.run(imgData, labelProbArray);
    promise.resolve(getResult());
  }


  private WritableNativeArray getResult() {
    WritableNativeArray result = new WritableNativeArray();

    //ArrayList<JSONObject> result = new ArrayList<JSONObject>();
    try {

      for (int i = 0; i < RESULTS_TO_SHOW; i++) {
          WritableNativeMap map = new WritableNativeMap();

          map.putString("label", Integer.toString(i));
          float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;
          map.putString("prob", String.valueOf(output));
          result.pushMap(map);

          Log.w("FIND_label ", Integer.toString(i));
          Log.w("FIND_prob ", String.valueOf(labelProbArray[0][i]));
      }
    } catch (Exception ex) {
      ex.printStackTrace();
    }

    return result;
  }

  private List<String> loadLabelList() throws IOException {
    Activity activity = getCurrentActivity();
      List<String> labelList = new ArrayList<String>();
      BufferedReader reader =
              new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
      String line;
      while ((line = reader.readLine()) != null) {
          labelList.add(line);
      }
      reader.close();
      return labelList;
  }

  private void convertBitmapToByteBuffer(Bitmap bitmap) {
      if (imgData == null) {
          return;
      }
      imgData.rewind();
      Log.w("FIND_bitmap width ", String.valueOf(bitmap.getWidth()));
      Log.w("FIND_bitmap height ", String.valueOf(bitmap.getHeight()));

      bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
      // Convert the image to floating point.
      int pixel = 0;
      //long startTime = SystemClock.uptimeMillis();
      for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
          for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
              final int val = intValues[pixel++];

              imgData.put((byte) ( ((val >> 16) & 0xFF)));
              imgData.put((byte) ( ((val >> 8) & 0xFF)));
              imgData.put((byte) ( (val & 0xFF)));
              Log.w("FIND_i", String.valueOf(i));
              Log.w("FIND_j", String.valueOf(j));
          }
      }
  }

  private MappedByteBuffer loadModelFile() throws IOException {
    Activity activity = getCurrentActivity();
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }

Одна из вещей, которые я не знаю, если это правильно, это как я корректирую свои выходные данные в вероятности. В методе getResult() я пишу это:

float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;

Я делю на 255, чтобы получить число от 0 до 1, но я не знаю, если это правильный.

Модель .tflite представляет собой квантованную модель uint8, если это помогает.

...