Недавно я пытался вычислить расстояния до 2 верхних соседей в Python Numba следующим образом
@jit(nopython=True)
def _latent_dim_kernel(data, pointers, indices, nrange, sampling_percentage = 1):
pdists_t2 = np.zeros((nrange, 2))
for a in range(nrange):
rct = 0
for b in range(nrange):
if np.random.random() > 1- sampling_percentage:
if a == b:
continue
r1 = _get_sparse_row(a, data, pointers, indices)
r2 = _get_sparse_row(b, data, pointers, indices)
dist = np.linalg.norm(r2 - r1)
if rct > 1:
if pdists_t2[a,0] > dist:
pdists_t2[a,0] = dist
elif pdists_t2[a,1] > dist:
pdists_t2[a,1] = dist
else:
pdists_t2[a,rct] = dist
rct += 1
return pdists_t2
Данные, указатели и индексы: x.data, x.indptr, x. индексы матрицы CSR (scipy). Это работает нормально, однако, значительно медленнее, чем
squareform(pdist(matrix)).sort(axis=1)[:,1:3]
Как я могу ускорить это без дополнительных затрат памяти?
Спасибо!