С этой проблемой сталкиваются из-за неправильного толкования того, как матрицы хранятся в numpy
, а также того, как должен работать процесс цветовой коррекции.
Процесс цветовой коррекции работает путем выполнения умножения матрицодин кортеж RGB с матрицей коррекции цвета (CCM). Таким образом, в основном это матричное умножение 2 матриц измерений (1, 3)
и (3, 3)
. Он работает следующим образом:
RGB * CCM = RGB_Corrected
(1, 3) * (3, 3) = (1, 3)
Этот процесс применяется индивидуально для каждого пикселя RGB.
Теперь, когда процесс цветовой коррекции определен, нам нужно правильно определить наши переменные, т.е. проверить их размеры.
СКК должен быть (3 , 3)
массив numpy, но в настоящее время он определен как массив (1 , 9)
. Это должно быть определено следующим образом, чтобы сделать его двухмерным (обратите внимание на дополнительные квадратные скобки):
ccm = np.array([ [ 1.0234, -0.2969, -0.2266],
[-0.5625, 1.6328, -0.0469],
[-0.0703, 0.2188, 0.6406] ])
Во-вторых, умножение матриц должно выполняться среди матриц совместимых размеров. Для этого изображение необходимо изменить таким образом, чтобы вместо трехмерной матрицы размеров (height, width, channels)
это была двумерная матрица с размерами ( width x height, channels )
. В вашем случае его форма должна быть ( 512 x 512, 3 )
.
. Это может быть достигнуто с помощью функции numpy.ndarray.reshape
следующим образом:
img2 = img.reshape((img.shape[0] * img.shape[1], 3))
Это позволит намвыполнить умножение матриц со следующими размерами.
(262144, 3) * (3, 3) = (262144, 3)
output = np.matmul(img2, ccm)
Далее,мы изменим изображение обратно к исходным размерам и вернем результат.
reshaped_back = output.reshape(img.shape).astype(img.dtype)
Окончательная функция цветовой коррекции может выглядеть следующим образом:
def color_correction(img, ccm):
'''
Input:
img: H*W*3 numpy array, input image
ccm: 3*3 numpy array, color correction matrix
Output:
output: H*W*3 numpy array, output image after color correction
'''
img2 = img.reshape((img.shape[0] * img.shape[1], 3))
output = np.matmul(img2, ccm)
return output.reshape(img.shape).astype(img.dtype)