Я обучил модель, используя тензорный поток, а затем преобразовал ее в формат tenorflow-lite. Затем я поместил модель в приложение Android и использовал интерпретатор tensorflowlite для вывода, и результатом было не что иное, как полностью черное изображение. Я перенес код из python в Java как есть, но все равно получаю мусор.
Есть идеи, где я могу здесь ошибиться.
Python Код:
def preprocess(img):
return (img / 255. - 0.5) * 2
def deprocess(img):
return (img + 1) / 2
img_size = 256
frozen_model_filename = os.path.join('model/tflite', 'model.tflite')
image_1 = cv2.resize(imread(image_1), (img_size, img_size))
X_1 = np.expand_dims(preprocess(image_1), 0)
X_1 = X_1.astype(np.float32)
image_2 = cv2.resize(imread(image_2), (img_size, img_size))
X_2 = np.expand_dims(preprocess(image_2), 0)
X_2 = X_2.astype(np.float32)
interpreter = tf.lite.Interpreter(model_path=frozen_model_filename)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], X_1)
interpreter.set_tensor(input_details[1]['index'], X_2)
interpreter.invoke()
Output = interpreter.get_tensor(output_details[0]['index'])
Output = deprocess(Output)
imsave('result_tflite.jpg', Output[0])
Соответствующий Java код для Android платформы:
private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
Bitmap resized = Bitmap.createScaledBitmap(bitmap, IMAGE_SIZE, IMAGE_SIZE, false);
ByteBuffer byteBuffer;
if(isQuant) {
byteBuffer = ByteBuffer.allocateDirect(BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
} else {
byteBuffer = ByteBuffer.allocateDirect(4 * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
}
byteBuffer.order(ByteOrder.nativeOrder());
int[] intValues = new int[IMAGE_SIZE * IMAGE_SIZE];
resized.getPixels(intValues, 0, resized.getWidth(), 0, 0, resized.getWidth(), resized.getHeight());
int pixel = 0;
byteBuffer.rewind();
for (int i = 0; i < IMAGE_SIZE; ++i) {
for (int j = 0; j < IMAGE_SIZE; ++j) {
final int val = intValues[pixel++];
if(isQuant){
byteBuffer.put((byte) ((val >> 16) & 0xFF));
byteBuffer.put((byte) ((val >> 8) & 0xFF));
byteBuffer.put((byte) (val & 0xFF));
} else {
byteBuffer.putFloat((((val >> 16) & 0xFF) - 0.5f) * 2.0f);
byteBuffer.putFloat((((val >> 8) & 0xFF) - 0.5f) * 2.0f);
byteBuffer.putFloat((((val) & 0xFF ) - 0.5f) * 2.0f);
}
}
}
return byteBuffer;
}
private Bitmap getOutputImage(ByteBuffer output){
output.rewind();
int outputWidth = IMAGE_SIZE;
int outputHeight = IMAGE_SIZE;
Bitmap bitmap = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888);
int [] pixels = new int[outputWidth * outputHeight];
for (int i = 0; i < outputWidth * outputHeight; i++) {
int a = 0xFF;
float r = (output.getFloat() + 1) / 2.0f;
float g = (output.getFloat() + 1) / 2.0f;
float b = (output.getFloat() + 1) / 2.0f;
pixels[i] = a << 24 | ((int) r << 16) | ((int) g << 8) | (int) b;
}
bitmap.setPixels(pixels, 0, outputWidth, 0, 0, outputWidth, outputHeight);
return bitmap;
}
private void runInference(){
ByteBuffer byteBufferX1 = convertBitmapToByteBuffer(bitmap_x1);
ByteBuffer byteBufferX2 = convertBitmapToByteBuffer(bitmap_x2);
Object[] inputs = {byteBufferX1, byteBufferX2};
ByteBuffer byteBufferOutput;
if(isQuant) {
byteBufferOutput = ByteBuffer.allocateDirect(BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
} else {
byteBufferOutput = ByteBuffer.allocateDirect(4 * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
}
byteBufferOutput.order(ByteOrder.nativeOrder());
byteBufferOutput.rewind();
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, byteBufferOutput);
interpreter.runForMultipleInputsOutputs(inputs, outputs);
ByteBuffer out = (ByteBuffer) outputs.get(0);
Bitmap outputBitmap = getOutputImage(out);
// outputBitmap is just a full black image
}