Как я могу преобразовать Tensor в растровое изображение на PyTorch Mobile? - PullRequest
0 голосов
/ 28 января 2020

Я нашел это решение (https://itnext.io/converting-pytorch-float-tensor-to-android-rgba-bitmap-with-kotlin-ffd4602a16b6), но когда я попытался преобразовать таким образом, я обнаружил, что размер inputTensor.dataAsFloatArray больше bitmap.width*bitmap.height. Как работает преобразование тензора в массив с плавающей точкой или есть какой-либо другой возможный способ преобразования тензора Pytorch в растровое изображение?

val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
    bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB
)

// Float array size is 196608 when width and height are 256x256 = 65536

val res = floatArrayToGrayscaleBitmap(inputTensor.dataAsFloatArray, bitmap.width, bitmap.height)


fun floatArrayToGrayscaleBitmap (
    floatArray: FloatArray,
    width: Int,
    height: Int,
    alpha :Byte = (255).toByte(),
    reverseScale :Boolean = false
) : Bitmap {

    // Create empty bitmap in RGBA format (even though it says ARGB but channels are RGBA)
    val bmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
    val byteBuffer = ByteBuffer.allocate(width*height*4)
    Log.d("App", floatArray.size.toString() + " " + (width * height * 4).toString())

    // mapping smallest value to 0 and largest value to 255
    val maxValue = floatArray.max() ?: 1.0f
    val minValue = floatArray.min() ?: 0.0f
    val delta = maxValue-minValue
    var tempValue :Byte

    // Define if float min..max will be mapped to 0..255 or 255..0
    val conversion = when(reverseScale) {
        false -> { v: Float -> ((v-minValue)/delta*255).toByte() }
        true -> { v: Float -> (255-(v-minValue)/delta*255).toByte() }
    }

    // copy each value from float array to RGB channels and set alpha channel
    floatArray.forEachIndexed { i, value ->
        tempValue = conversion(value)
        byteBuffer.put(4*i, tempValue)
        byteBuffer.put(4*i+1, tempValue)
        byteBuffer.put(4*i+2, tempValue)
        byteBuffer.put(4*i+3, alpha)
    }

    bmp.copyPixelsFromBuffer(byteBuffer)

    return bmp
}
...