Вывод Tensorflow на android с получением мусорного результата - PullRequest
0 голосов
/ 25 мая 2020

Я обучил модель, используя тензорный поток, а затем преобразовал ее в формат tenorflow-lite. Затем я поместил модель в приложение Android и использовал интерпретатор tensorflowlite для вывода, и результатом было не что иное, как полностью черное изображение. Я перенес код из python в Java как есть, но все равно получаю мусор.

Есть идеи, где я могу здесь ошибиться.

Python Код:

def preprocess(img):
    return (img / 255. - 0.5) * 2

def deprocess(img):
    return (img + 1) / 2

img_size = 256

frozen_model_filename = os.path.join('model/tflite', 'model.tflite')

image_1 = cv2.resize(imread(image_1), (img_size, img_size))
X_1 = np.expand_dims(preprocess(image_1), 0)
X_1 = X_1.astype(np.float32)

image_2 = cv2.resize(imread(image_2), (img_size, img_size))
X_2 = np.expand_dims(preprocess(image_2), 0)
X_2 = X_2.astype(np.float32)


interpreter = tf.lite.Interpreter(model_path=frozen_model_filename)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


interpreter.set_tensor(input_details[0]['index'], X_1)
interpreter.set_tensor(input_details[1]['index'], X_2)
interpreter.invoke()

Output = interpreter.get_tensor(output_details[0]['index'])
Output = deprocess(Output)

imsave('result_tflite.jpg', Output[0])

Соответствующий Java код для Android платформы:

private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {

    Bitmap resized = Bitmap.createScaledBitmap(bitmap, IMAGE_SIZE, IMAGE_SIZE, false);

    ByteBuffer byteBuffer;

    if(isQuant) {
        byteBuffer = ByteBuffer.allocateDirect(BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    } else {
        byteBuffer = ByteBuffer.allocateDirect(4 * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    }

    byteBuffer.order(ByteOrder.nativeOrder());

    int[] intValues = new int[IMAGE_SIZE * IMAGE_SIZE];
    resized.getPixels(intValues, 0, resized.getWidth(), 0, 0, resized.getWidth(), resized.getHeight());

    int pixel = 0;
    byteBuffer.rewind();

    for (int i = 0; i < IMAGE_SIZE; ++i) {
        for (int j = 0; j < IMAGE_SIZE; ++j) {
            final int val = intValues[pixel++];
            if(isQuant){
                byteBuffer.put((byte) ((val >> 16) & 0xFF));
                byteBuffer.put((byte) ((val >> 8) & 0xFF));
                byteBuffer.put((byte) (val & 0xFF));
            } else {
                byteBuffer.putFloat((((val >> 16) & 0xFF) - 0.5f) * 2.0f);
                byteBuffer.putFloat((((val >> 8) & 0xFF) - 0.5f) * 2.0f);
                byteBuffer.putFloat((((val) & 0xFF ) - 0.5f) *  2.0f);
            }
        }
    }
    return byteBuffer;
}

private Bitmap getOutputImage(ByteBuffer output){
    output.rewind();

    int outputWidth = IMAGE_SIZE;
    int outputHeight = IMAGE_SIZE;
    Bitmap bitmap = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888);
    int [] pixels = new int[outputWidth * outputHeight];
    for (int i = 0; i < outputWidth * outputHeight; i++) {
        int a = 0xFF;

        float r = (output.getFloat() + 1) / 2.0f;
        float g = (output.getFloat() + 1) / 2.0f;
        float b = (output.getFloat() + 1) / 2.0f;

        pixels[i] = a << 24 | ((int) r << 16) | ((int) g << 8) | (int) b;
    }
    bitmap.setPixels(pixels, 0, outputWidth, 0, 0, outputWidth, outputHeight);
    return bitmap;
}

private void runInference(){

    ByteBuffer byteBufferX1 = convertBitmapToByteBuffer(bitmap_x1);
    ByteBuffer byteBufferX2 = convertBitmapToByteBuffer(bitmap_x2);

    Object[] inputs = {byteBufferX1, byteBufferX2};

    ByteBuffer byteBufferOutput;

    if(isQuant) {
        byteBufferOutput = ByteBuffer.allocateDirect(BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    } else {
        byteBufferOutput = ByteBuffer.allocateDirect(4 * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    }

    byteBufferOutput.order(ByteOrder.nativeOrder());
    byteBufferOutput.rewind();

    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, byteBufferOutput);

    interpreter.runForMultipleInputsOutputs(inputs, outputs);
    ByteBuffer out = (ByteBuffer) outputs.get(0);
    Bitmap outputBitmap = getOutputImage(out);

    // outputBitmap is just a full black image
}

1 Ответ

0 голосов
/ 26 мая 2020

Оба интерпретатора Java и Python основаны на реализации C ++, поэтому результаты должны быть одинаковыми. Ошибка должна быть в вашем коде JAVA. Здесь я думаю, вы забыли умножить и разделить на 255.

...