Обновление: добавлена оптимизация blas
Существует несколько простых и очень эффективных оптимизаций:
(1) векторизация, векторизация!Не так сложно векторизовать практически все в этом коде.См. Ниже.
(2) используйте правильный поиск, то есть необычное индексирование, а не np.take
(3) используйте Cholesky decomp.С blas dtrmm
мы можем использовать его треугольную структуру
И вот код.Просто добавьте его в конец кода OP (под EDIT 2).Если вы не очень терпеливы, вы, вероятно, также захотите закомментировать строки lut = generate_lut()
и result = calculate_distance(lut, rgb)
и все ссылки на cv2.Я также добавил случайную строку в x
, чтобы сделать ее ковариационную матрицу неособой.
class Full_Model(Model):
ch = np.linalg.cholesky(Model.inverse_pooled_covariance)
chx = Model.x_mean@ch
def rgb2something_vectorized(rgb):
return np.sqrt(np.sum(((rgb - Full_Model.x_mean)@Full_Model.ch)**2, axis=-1))
from scipy.linalg import blas
def rgb2something_blas(rgb):
*shp, nchan = rgb.shape
return np.sqrt(np.einsum('...i,...i', *2*(blas.dtrmm(1, Full_Model.ch.T, rgb.reshape(-1, nchan).T, 0, 0, 0, 0, 0).T - Full_Model.chx,))).reshape(shp)
def generate_lut_vectorized():
return rgb2something_vectorized(np.transpose(np.indices((256, 256, 256))))
def generate_lut_blas():
rng = np.arange(256)
arr = np.empty((256, 256, 256, 3))
arr[0, ..., 0] = rng
arr[0, ..., 1] = rng[:, None]
arr[1:, ...] = arr[0]
arr[..., 2] = rng[:, None, None]
return rgb2something_blas(arr)
def calculate_distance_vectorized(lut, input_image):
return lut[input_image[..., 2], input_image[..., 1], input_image[..., 0]]
# test code
def random_check_lut(lut):
"""Because the original lut generator is excruciatingly slow,
we only compare a random sample, using the original code
"""
levels = 256
levels2 = levels**2
lut = lut.ravel()
levels_range = range(0, levels)
for r, g, b in np.random.randint(0, 256, (1000, 3)):
assert np.isclose(lut[r + (g * levels) + (b * levels2)], rgb2something(r, g, b))
import time
td = []
td.append((time.time(), 'create lut vectorized'))
lutv = generate_lut_vectorized()
td.append((time.time(), 'create lut using blas'))
lutb = generate_lut_blas()
td.append((time.time(), 'lookup using np.take'))
res = calculate_distance(lutv, rgb)
td.append((time.time(), 'process on the fly (no lookup)'))
resotf = rgb2something_vectorized(rgb)
td.append((time.time(), 'process on the fly (blas)'))
resbla = rgb2something_blas(rgb)
td.append((time.time(), 'lookup using fancy indexing'))
resv = calculate_distance_vectorized(lutv, rgb)
td.append((time.time(), None))
print("sanity checks ... ", end='')
assert np.allclose(res, resotf) and np.allclose(res, resv) \
and np.allclose(res, resbla) and np.allclose(lutv, lutb)
random_check_lut(lutv)
print('all ok\n')
t, d = zip(*td)
for ti, di in zip(np.diff(t), d):
print(f'{di:32s} {ti:10.3f} seconds')
Пример выполнения:
sanity checks ... all ok
create lut vectorized 1.116 seconds
create lut using blas 0.917 seconds
lookup using np.take 0.398 seconds
process on the fly (no lookup) 0.127 seconds
process on the fly (blas) 0.069 seconds
lookup using fancy indexing 0.064 seconds
Мы видим, что лучший поиск превосходитлучший расчет на летуТем не менее, пример может переоценить стоимость поиска, поскольку случайные пиксели предположительно менее дружественны по отношению к кэшу, чем естественные изображения.
Оригинальный ответ (возможно, еще полезный для некоторых)
Если rgb2 что-то не может быть векторизовано,и вы хотите обработать одно типичное изображение, тогда вы можете получить приличное ускорение, используя np.unique
.
Если rgb2something стоит дорого и нужно обрабатывать несколько изображений, тогда unique
можно объединить с кэшированием, котороеудобно делать только с использованием (10) * --- (незначительного) камня преткновения: аргументы должны быть хэшируемыми.Оказывается, что изменение в коде, которое это принуждает (приведение rgb-массивов к 3-байтовым строкам), повышает производительность.
Использование таблицы полного просмотра стоит только в том случае, если у вас огромное количествопиксели, покрывающие большинство оттенков.В этом случае самым быстрым способом является использование простого индексирования для фактического поиска.
import numpy as np
import time
import functools
def rgb2something(rgb):
# waste some time:
np.exp(0.1*rgb)
return rgb.mean()
@functools.lru_cache(None)
def rgb2something_lru(rgb):
rgb = np.frombuffer(rgb, np.uint8)
# waste some time:
np.exp(0.1*rgb)
return rgb.mean()
def apply_to_img(img):
shp = img.shape
return np.reshape([rgb2something(x) for x in img.reshape(-1, shp[-1])], shp[:2])
def apply_to_img_lru(img):
shp = img.shape
return np.reshape([rgb2something_lru(x) for x in img.ravel().view('S3')], shp[:2])
def apply_to_img_smart(img, print_stats=True):
shp = img.shape
unq, bck = np.unique(img.reshape(-1, shp[-1]), return_inverse=True, axis=0)
if print_stats:
print('total no pixels', shp[0]*shp[1], '\nno unique pixels', len(unq))
return np.array([rgb2something(x) for x in unq])[bck].reshape(shp[:2])
def apply_to_img_smarter(img, print_stats=True):
shp = img.shape
unq, bck = np.unique(img.ravel().view('S3'), return_inverse=True)
if print_stats:
print('total no pixels', shp[0]*shp[1], '\nno unique pixels', len(unq))
return np.array([rgb2something_lru(x) for x in unq])[bck].reshape(shp[:2])
def make_full_lut():
x = np.empty((3,), np.uint8)
return np.reshape([rgb2something(x) for x[0] in range(256)
for x[1] in range(256) for x[2] in range(256)],
(256, 256, 256))
def make_full_lut_cheat(): # for quicker testing lookup
i, j, k = np.ogrid[:256, :256, :256]
return (i + j + k) / 3
def apply_to_img_full_lut(img, lut):
return lut[(*np.moveaxis(img, 2, 0),)]
from scipy.misc import face
t0 = time.perf_counter()
bw = apply_to_img(face())
t1 = time.perf_counter()
print('naive ', t1-t0, 'seconds')
t0 = time.perf_counter()
bw = apply_to_img_lru(face())
t1 = time.perf_counter()
print('lru first time ', t1-t0, 'seconds')
t0 = time.perf_counter()
bw = apply_to_img_lru(face())
t1 = time.perf_counter()
print('lru second time ', t1-t0, 'seconds')
t0 = time.perf_counter()
bw = apply_to_img_smart(face(), False)
t1 = time.perf_counter()
print('using unique: ', t1-t0, 'seconds')
rgb2something_lru.cache_clear()
t0 = time.perf_counter()
bw = apply_to_img_smarter(face(), False)
t1 = time.perf_counter()
print('unique and lru first: ', t1-t0, 'seconds')
t0 = time.perf_counter()
bw = apply_to_img_smarter(face(), False)
t1 = time.perf_counter()
print('unique and lru second:', t1-t0, 'seconds')
t0 = time.perf_counter()
lut = make_full_lut_cheat()
t1 = time.perf_counter()
print('creating full lut: ', t1-t0, 'seconds')
t0 = time.perf_counter()
bw = apply_to_img_full_lut(face(), lut)
t1 = time.perf_counter()
print('using full lut: ', t1-t0, 'seconds')
print()
apply_to_img_smart(face())
import Image
Image.fromarray(bw.astype(np.uint8)).save('bw.png')
Пример выполнения:
naive 6.8886632949870545 seconds
lru first time 1.7458112589956727 seconds
lru second time 0.4085628940083552 seconds
using unique: 2.0951434450107627 seconds
unique and lru first: 2.0168916099937633 seconds
unique and lru second: 0.3118703299842309 seconds
creating full lut: 151.17599205300212 seconds
using full lut: 0.12164952099556103 seconds
total no pixels 786432
no unique pixels 134105