Я бы хотел ускорить cdist
между двумя numpy.ndarray
, используя numba
следующим образом:
import numpy as np
from numba import njit, jit
from scipy.spatial.distance import cdist
import time
@njit
def dist_scipy(a, b):
d = cdist(a, b, 'euclidean')
d = np.transpose(d)
sorted_d = np.sort(d)
sorted_ind = np.argsort(d)
return sorted_d, sorted_ind
def get_a_b(r=10**4,c=10** 1):
a = np.random.uniform(-1, 1, (r, c)).astype('f')
b = np.random.uniform(-1, 1, (r, c)).astype('f')
return a,b
if __name__ == "__main__":
a, b = get_a_b()
st_t = time.time()
dist_scipy(a,b)
print('it took {} s'.format(time.time()-st_t))
В python2
, после $ pip install numba-scipy
я получаю следующую ошибку :
Traceback (most recent call last):
File "stackoverflow_Q.py", line 31, in <module>
dist_scipy(a,b)
File "/usr/local/lib/python2.7/dist-packages/numba/dispatcher.py", line 420, in _compile_for_args
raise e
File "/usr/local/lib/python2.7/dist-packages/numba_scipy/special/overloads.py", line 12
f = signatures.name_and_types_to_pointer[(name, *signature)]
^
SyntaxError: invalid syntax
И в python3, после $ conda install -c conda-forge numba numba-scipy
, я получаю следующую ошибку:
Traceback (most recent call last):
File "numba_scipy_test.py", line 31, in <module>
dist_scipy(a,b)
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 420, in _compile_for_args
raise e
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 353, in _compile_for_args
return self.compile(tuple(argtypes))
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/compiler_lock.py", line 32, in _acquire_compile_lock
return func(*args, **kwargs)
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 768, in compile
cres = self._compiler.compile(args, return_type)
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 77, in compile
status, retval = self._compile_cached(args, return_type)
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 91, in _compile_cached
retval = self._compile_core(args, return_type)
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 109, in _compile_core
pipeline_class=self.pipeline_class)
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/compiler.py", line 550, in compile_extra
args, return_type, flags, locals)
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/compiler.py", line 281, in __init__
targetctx.refresh()
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/targets/base.py", line 281, in refresh
self.load_additional_registries()
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/targets/cpu.py", line 80, in load_additional_registries
numba.entrypoints.init_all()
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/entrypoints.py", line 24, in init_all
func()
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/__init__.py", line 12, in _init_extension
from . import special
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/special/__init__.py", line 1, in <module>
from . import overloads as _overloads
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/special/overloads.py", line 4, in <module>
from . import signatures
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/special/signatures.py", line 376, in <module>
('pdtr', numba.types.float64, numba.types.float64): ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double)(get_cython_function_address('scipy.special.cython_special', '__pyx_fuse_0pdtr')),
File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/extending.py", line 406, in get_cython_function_address
return _import_cython_function(module_name, function_name)
ValueError: No function '__pyx_fuse_0pdtr' found in __pyx_capi__ of 'scipy.special.cython_special'