Довольно странно, что нумба может быть намного медленнее.
Это не так уж и странно.Когда вы вызываете функции NumPy внутри функции numba, вы вызываете numba-версию этих функций.Они могут быть быстрее, медленнее или такими же быстрыми, как версии NumPy.Вам может повезти или вам не повезло (вам не повезло!).Но даже в функции numba вы все равно создаете много временных файлов, потому что вы используете функции NumPy (один временный массив для результата точки, один для каждого квадрата и суммы, один для точки плюс первая сумма), поэтому вы не пользуетесь преимуществамиВозможности с Numba.
Я использую это неправильно?
По существу: Да.
Мне действительно нужно ускорить это
Хорошо, я попробую.
Начнем с развертывания суммы квадратов по вызовам оси 1:
import numba as nb
@nb.njit
def sum_squares_2d_array_along_axis1(arr):
res = np.empty(arr.shape[0], dtype=arr.dtype)
for o_idx in range(arr.shape[0]):
sum_ = 0
for i_idx in range(arr.shape[1]):
sum_ += arr[o_idx, i_idx] * arr[o_idx, i_idx]
res[o_idx] = sum_
return res
@nb.njit
def euclidean_distance_square_numba_v1(x1, x2):
return -2 * np.dot(x1, x2.T) + np.expand_dims(sum_squares_2d_array_along_axis1(x1), axis=1) + sum_squares_2d_array_along_axis1(x2)
На моем компьютере этоуже в 2 раза быстрее, чем код NumPy, и почти в 10 раз быстрее, чем ваш исходный код Numba.
Как показывает опыт, его увеличение в 2 раза быстрее, чем NumPy, обычно является пределом (по крайней мере, если версия NumPy не слишком сложна)или неэффективно), однако вы можете выжать немного больше, развернув все:
import numba as nb
@nb.njit
def euclidean_distance_square_numba_v2(x1, x2):
f1 = 0.
for i_idx in range(x1.shape[1]):
f1 += x1[0, i_idx] * x1[0, i_idx]
res = np.empty(x2.shape[0], dtype=x2.dtype)
for o_idx in range(x2.shape[0]):
val = 0
for i_idx in range(x2.shape[1]):
val_from_x2 = x2[o_idx, i_idx]
val += (-2) * x1[0, i_idx] * val_from_x2 + val_from_x2 * val_from_x2
val += f1
res[o_idx] = val
return res
Но это только дает улучшение на ~ 10-20% по сравнению с последним подходом.
При этом pВозможно, вы поймете, что можете упростить код (даже если он, вероятно, не ускорит его):
import numba as nb
@nb.njit
def euclidean_distance_square_numba_v3(x1, x2):
res = np.empty(x2.shape[0], dtype=x2.dtype)
for o_idx in range(x2.shape[0]):
val = 0
for i_idx in range(x2.shape[1]):
tmp = x1[0, i_idx] - x2[o_idx, i_idx]
val += tmp * tmp
res[o_idx] = val
return res
Да, это выглядит довольно просто, и на самом деле это не медленнее.
Однако при всем волнении я забыл упомянуть очевидное решение: scipy.spatial.distance.cdist
, которое имеет опцию sqeuclidean
(квадрат евклидова расстояния):
from scipy.spatial import distance
distance.cdist(x1, x2, metric='sqeuclidean')
Это на самом деле не быстрее, чем numba, но доступно без написания вашей собственной функции ...
Тесты
Проверка на правильность и прогрев:
x1 = np.array([[1.,2,3]])
x2 = np.array([[1.,2,3], [2,3,4], [3,4,5], [4,5,6], [5,6,7]])
res1 = euclidean_distance_square(x1, x2)
res2 = euclidean_distance_square_numba_original(x1, x2)
res3 = euclidean_distance_square_numba_v1(x1, x2)
res4 = euclidean_distance_square_numba_v2(x1, x2)
res5 = euclidean_distance_square_numba_v3(x1, x2)
np.testing.assert_array_equal(res1, res2)
np.testing.assert_array_equal(res1, res3)
np.testing.assert_array_equal(res1[0], res4)
np.testing.assert_array_equal(res1[0], res5)
np.testing.assert_almost_equal(res1, distance.cdist(x1, x2, metric='sqeuclidean'))
Время:
x1 = np.random.random((1, 512))
x2 = np.random.random((1000000, 512))
%timeit euclidean_distance_square(x1, x2)
# 2.09 s ± 54.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_original(x1, x2)
# 10.9 s ± 158 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_v1(x1, x2)
# 907 ms ± 7.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_v2(x1, x2)
# 715 ms ± 15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_v3(x1, x2)
# 731 ms ± 34.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit distance.cdist(x1, x2, metric='sqeuclidean')
# 706 ms ± 4.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Примечание. Если у вас есть массивы целых чисел, вы можете изменить жестко закодированный 0.0
в функциях numba на 0
.