Если вы действительно хотите, чтобы numba выполняла быстро, вам нужно jit
функцию в режиме nopython
, в противном случае numba может вернуться в режим объекта, который медленнее (и может быть довольно медленным).
Однако ваша функция не может быть скомпилирована (начиная с версии 0.43.1 numba) в режиме nopython, потому что:
- аргумент
dtype
для np.empty
. np.float
- это просто Pythons float
и будет переведено NumPy (но не numba) в np.float_
. Если вы используете нумбу, вы должны использовать это.
- В numba отсутствует поддержка строк. Поэтому строка
types[k] == 'float64'
не будет компилироваться.
Первая проблема тривиально исправлена. Что касается второй проблемы: вместо того, чтобы пытаться заставить сравнения строк работать, просто предоставьте логический массив. Использование логического массива и вычисление одного логического значения для истины также будет значительно быстрее, чем сравнение до 7 символов. Особенно, если это в самом внутреннем цикле!
Так что это может выглядеть так:
import numpy as np
import numba as nb
@nb.njit
def pairwise_numba(X, is_float_type):
m = X.shape[0]
n = X.shape[1]
D = np.empty((int(m * (m - 1) / 2), 1), dtype=np.float64) # corrected dtype
ind = 0
for i in range(m):
for j in range(i+1, m):
d = 0.0
for k in range(n):
if is_float_type[k]:
tmp = X[i, k] - X[j, k]
d += tmp * tmp
else:
if X[i, k] != X[j, k]:
d += 1.
D[ind] = np.sqrt(d)
ind += 1
return D.reshape(1, -1)[0]
dists = pairwise_numba(vectors, types == 'float64') # pass in the boolean array
Однако вы можете упростить логику, если вы объедините scipy.spatial.distances.pdist
для типов с плавающей точкой с логикой numba для подсчета неравных категорий:
from scipy.spatial.distance import pdist
@nb.njit
def categorial_sum(X):
m = X.shape[0]
n = X.shape[1]
D = np.zeros(int(m * (m - 1) / 2), dtype=np.float64) # corrected dtype
ind = 0
for i in range(m):
for j in range(i+1, m):
d = 0.0
for k in range(n):
if X[i, k] != X[j, k]:
d += 1.
D[ind] = d
ind += 1
return D
def pdist_with_categorial(vectors, types):
where_float_type = types == 'float64'
# calculate the squared distance of the float values
distances_squared = pdist(vectors[:, where_float_type], metric='sqeuclidean')
# sum the number of mismatched categorials and add that to the distances
# and then take the square root
return np.sqrt(distances_squared + categorial_sum(vectors[:, ~where_float_type]))
Это не будет значительно быстрее, но это значительно упростило логику в функции numba.
Тогда вы также можете избежать создания дополнительных массивов, передав квадратные расстояния функции numba:
@nb.njit
def add_categorial_sum_and_sqrt(X, D):
m = X.shape[0]
n = X.shape[1]
ind = 0
for i in range(m):
for j in range(i+1, m):
d = 0.0
for k in range(n):
if X[i, k] != X[j, k]:
d += 1.
D[ind] = np.sqrt(D[ind] + d)
ind += 1
return D
def pdist_with_categorial(vectors, types):
where_float_type = types == 'float64'
distances_squared = pdist(vectors[:, where_float_type], metric='sqeuclidean')
return add_categorial_sum_and_sqrt(vectors[:, ~where_float_type], distances_squared)