Решение с использованием Numba
В некоторых случаях это очень легко сделать, если поддерживаются все используемые вами функции. В вашем коде win = signal.detrend(win, type = 'linear')
- это часть, которую вы должны реализовать в Numba, потому что эта функция не поддерживается.
Реализация отклонения в Numba
Если вы посмотрите на исходный код отклонения и извлечете соответствующие части для вашей проблемы, это можетвыглядит так:
@nb.njit()
def detrend(w):
Npts=w.shape[0]
A=np.empty((Npts,2),dtype=w.dtype)
for i in range(Npts):
A[i,0]=1.*(i+1) / Npts
A[i,1]=1.
coef, resids, rank, s = np.linalg.lstsq(A, w.T)
out=w.T- np.dot(A, coef)
return out.T
Я также реализовал более быстрое решение для np.max(np.isnan(win)) == 1
@nb.njit()
def isnan(win):
for i in range(win.shape[0]):
for j in range(win.shape[1]):
if np.isnan(win[i,j]):
return True
return False
Основная функция
Как я использовал Numbaздесь распараллеливание очень простое, просто prange на внешнем цикле и
import numpy as np
import numba as nb
@nb.njit(parallel=True)
def RMSH_det_nb(DEM, w):
[nrows, ncols] = np.shape(DEM)
#create an empty array to store result
rms = DEM*np.nan
for i in nb.prange(w+1,nrows-w):
for j in range(w+1,ncols-w):
win = DEM[i-w:i+w-1,j-w:j+w-1]
if isnan(win):
rms[i,j] = np.nan
else:
win = detrend(win)
z = win.flatten()
nz = z.size
rootms = np.sqrt(1 / (nz - 1) * np.sum((z-np.mean(z))**2))
rms[i,j] = rootms
return rms
Время (маленький пример)
w = 10
DEM=np.random.rand(100, 100).astype(np.float32)
res1=RMSH_det(DEM, w)
res2=RMSH_det_nb(DEM, w)
print(np.allclose(res1,res2,equal_nan=True))
#True
%timeit res1=RMSH_det(DEM, w)
#1.59 s ± 72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res2=RMSH_det_nb(DEM, w) #approx. 55 times faster
#29 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Время длябольшие массивы
w = 10
DEM=np.random.rand(1355, 1165).astype(np.float32)
%timeit res2=RMSH_det_nb(DEM, w)
#6.63 s ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[Редактировать] Реализация с использованием нормальных уравнений
Переопределенная система
Этот метод имеет меньшая числовая точность . Хотя это решение намного быстрее.
@nb.njit()
def isnan(win):
for i in range(win.shape[0]):
for j in range(win.shape[1]):
if win[i,j]==np.nan:
return True
return False
@nb.njit()
def detrend(w):
Npts=w.shape[0]
A=np.empty((Npts,2),dtype=w.dtype)
for i in range(Npts):
A[i,0]=1.*(i+1) / Npts
A[i,1]=1.
coef, resids, rank, s = np.linalg.lstsq(A, w.T)
out=w.T- np.dot(A, coef)
return out.T
@nb.njit()
def detrend_2(w,T1,A):
T2=np.dot(A.T,w.T)
coef=np.linalg.solve(T1,T2)
out=w.T- np.dot(A, coef)
return out.T
@nb.njit(parallel=True)
def RMSH_det_nb_normal_eq(DEM,w):
[nrows, ncols] = np.shape(DEM)
#create an empty array to store result
rms = DEM*np.nan
Npts=w*2-1
A=np.empty((Npts,2),dtype=DEM.dtype)
for i in range(Npts):
A[i,0]=1.*(i+1) / Npts
A[i,1]=1.
T1=np.dot(A.T,A)
nz = Npts**2
for i in nb.prange(w+1,nrows-w):
for j in range(w+1,ncols-w):
win = DEM[i-w:i+w-1,j-w:j+w-1]
if isnan(win):
rms[i,j] = np.nan
else:
win = detrend_2(win,T1,A)
rootms = np.sqrt(1 / (nz - 1) * np.sum((win-np.mean(win))**2))
rms[i,j] = rootms
return rms
Время
w = 10
DEM=np.random.rand(100, 100).astype(np.float32)
res1=RMSH_det(DEM, w)
res2=RMSH_det_nb(DEM, w)
print(np.allclose(res1,res2,equal_nan=True))
#True
%timeit res1=RMSH_det(DEM, w)
#1.59 s ± 72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res2=RMSH_det_nb_normal_eq(DEM,w)
#7.97 ms ± 89.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Оптимизированное решение с использованием нормальных уравнений
Временные массивы используются повторно, чтобы избежать дорогостоящих выделений памяти, и используется пользовательская реализация для умножения матриц. Это рекомендуется только для очень маленьких матриц, в большинстве других случаев np.dot (sgeemm) будет намного быстрее.
@nb.njit()
def matmult_2(A,B,out):
for j in range(B.shape[1]):
acc1=nb.float32(0)
acc2=nb.float32(0)
for k in range(B.shape[0]):
acc1+=A[0,k]*B[k,j]
acc2+=A[1,k]*B[k,j]
out[0,j]=acc1
out[1,j]=acc2
return out
@nb.njit(fastmath=True)
def matmult_mod(A,B,w,out):
for j in range(B.shape[1]):
for i in range(A.shape[0]):
acc=nb.float32(0)
acc+=A[i,0]*B[0,j]+A[i,1]*B[1,j]
out[j,i]=acc-w[j,i]
return out
@nb.njit()
def detrend_2_opt(w,T1,A,Tempvar_1,Tempvar_2):
T2=matmult_2(A.T,w.T,Tempvar_1)
coef=np.linalg.solve(T1,T2)
return matmult_mod(A, coef,w,Tempvar_2)
@nb.njit(parallel=True)
def RMSH_det_nb_normal_eq_opt(DEM,w):
[nrows, ncols] = np.shape(DEM)
#create an empty array to store result
rms = DEM*np.nan
Npts=w*2-1
A=np.empty((Npts,2),dtype=DEM.dtype)
for i in range(Npts):
A[i,0]=1.*(i+1) / Npts
A[i,1]=1.
T1=np.dot(A.T,A)
nz = Npts**2
for i in nb.prange(w+1,nrows-w):
Tempvar_1=np.empty((2,Npts),dtype=DEM.dtype)
Tempvar_2=np.empty((Npts,Npts),dtype=DEM.dtype)
for j in range(w+1,ncols-w):
win = DEM[i-w:i+w-1,j-w:j+w-1]
if isnan(win):
rms[i,j] = np.nan
else:
win = detrend_2_opt(win,T1,A,Tempvar_1,Tempvar_2)
rootms = np.sqrt(1 / (nz - 1) * np.sum((win-np.mean(win))**2))
rms[i,j] = rootms
return rms
Время
w = 10
DEM=np.random.rand(100, 100).astype(np.float32)
res1=RMSH_det(DEM, w)
res2=RMSH_det_nb_normal_eq_opt(DEM, w)
print(np.allclose(res1,res2,equal_nan=True))
#True
%timeit res1=RMSH_det(DEM, w)
#1.59 s ± 72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res2=RMSH_det_nb_normal_eq_opt(DEM,w)
#4.66 ms ± 87.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Времена для isnan
Эта функция является совершенно другой реализацией. Это намного быстрее, если NaN находится в начале массива, но в любом случае, даже если и нет, то есть некоторое ускорение. Я сравнил его с небольшими массивами (приблизительный размер окна) и большим размером, предложенным @ user3666197.
case_1=np.full((20,20),np.nan)
case_2=np.full((20,20),0.)
case_2[10,10]=np.nan
case_3=np.full((20,20),0.)
case_4 = np.full( ( int( 1E4 ), int( 1E4 ) ),np.nan)
case_5 = np.ones( ( int( 1E4 ), int( 1E4 ) ) )
%timeit np.any(np.isnan(case_1))
%timeit np.any(np.isnan(case_2))
%timeit np.any(np.isnan(case_3))
%timeit np.any(np.isnan(case_4))
%timeit np.any(np.isnan(case_5))
#2.75 µs ± 73.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#2.75 µs ± 46.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#2.76 µs ± 32.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#81.3 ms ± 2.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#86.7 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit isnan(case_1)
%timeit isnan(case_2)
%timeit isnan(case_3)
%timeit isnan(case_4)
%timeit isnan(case_5)
#244 ns ± 5.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
#357 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
#475 ns ± 9.28 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
#235 ns ± 0.933 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
#58.8 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)