Ошибка при использовании расширения numba-scipy для вычисления cdist в python2.7 и python3.X - PullRequest
0 голосов
/ 03 апреля 2020

Я хотел бы ускорить вычисление cdist для двух numpy.ndarray с помощью numba следующим образом:

import numpy as np
from numba import njit, jit
from scipy.spatial.distance import cdist
import time

@njit
def dist_scipy(a, b):
    """Return sorted pairwise Euclidean distances between ``a`` and ``b``.

    ``cdist(a, b, 'euclidean')`` yields one row per row of ``a`` and one
    column per row of ``b``; the result is transposed so that each row
    corresponds to a row of ``b``.

    Returns:
        A tuple ``(sorted_d, sorted_ind)``: the transposed distance matrix
        sorted along its last axis, and the indices that produce that
        ordering.
    """
    # NOTE(review): this SciPy call sits inside an @njit-compiled function;
    # confirm that Numba (with numba-scipy) can actually compile `cdist`
    # in nopython mode.
    d = cdist(a, b, 'euclidean')
    # Flip the orientation so rows index `b` and columns index `a`.
    d = np.transpose(d)
    # np.sort / np.argsort operate on the last axis by default, i.e. within
    # each row; both are computed independently from the same matrix.
    sorted_d = np.sort(d)
    sorted_ind = np.argsort(d)
    return sorted_d, sorted_ind

def get_a_b(r=10**4, c=10**1):
    """Create two random single-precision matrices of shape ``(r, c)``.

    Values are drawn from ``np.random.uniform`` with bounds -1 and 1 and
    then cast with ``astype('f')``. ``a`` is drawn before ``b``.

    Args:
        r: Number of rows in each matrix.
        c: Number of columns in each matrix.

    Returns:
        A tuple ``(a, b)`` of two arrays with the same shape.
    """
    shape = (r, c)
    # The generator is consumed left to right, so `a` is sampled first and
    # `b` second, which keeps the global RNG stream order unchanged.
    a, b = (
        np.random.uniform(low=-1, high=1, size=shape).astype('f')
        for _ in range(2)
    )
    return a, b

if __name__ == "__main__":
    # Generate the inputs before starting the clock so that only the
    # distance computation is timed.
    a, b = get_a_b()
    started = time.time()
    # This is the first call to the @njit-decorated function, so the measured
    # wall-clock time also includes whatever happens on first invocation.
    dist_scipy(a, b)
    elapsed = time.time() - started
    print('it took {} s'.format(elapsed))

В python2 после $ pip install numba-scipy я получаю следующую ошибку:

Traceback (most recent call last):
  File "stackoverflow_Q.py", line 31, in <module>
    dist_scipy(a,b)
  File "/usr/local/lib/python2.7/dist-packages/numba/dispatcher.py", line 420, in _compile_for_args
    raise e
  File "/usr/local/lib/python2.7/dist-packages/numba_scipy/special/overloads.py", line 12
    f = signatures.name_and_types_to_pointer[(name, *signature)]
                                                    ^
SyntaxError: invalid syntax

И в python3, после $ conda install -c conda-forge numba numba-scipy, я получаю следующую ошибку:

Traceback (most recent call last):
  File "numba_scipy_test.py", line 31, in <module>
    dist_scipy(a,b)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 420, in _compile_for_args
    raise e
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 353, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 768, in compile
    cres = self._compiler.compile(args, return_type)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 77, in compile
    status, retval = self._compile_cached(args, return_type)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 91, in _compile_cached
    retval = self._compile_core(args, return_type)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 109, in _compile_core
    pipeline_class=self.pipeline_class)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/compiler.py", line 550, in compile_extra
    args, return_type, flags, locals)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/compiler.py", line 281, in __init__
    targetctx.refresh()
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/targets/base.py", line 281, in refresh
    self.load_additional_registries()
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/targets/cpu.py", line 80, in load_additional_registries
    numba.entrypoints.init_all()
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/entrypoints.py", line 24, in init_all
    func()
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/__init__.py", line 12, in _init_extension
    from . import special
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/special/__init__.py", line 1, in <module>
    from . import overloads as _overloads
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/special/overloads.py", line 4, in <module>
    from . import signatures
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/special/signatures.py", line 376, in <module>
    ('pdtr', numba.types.float64, numba.types.float64): ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double)(get_cython_function_address('scipy.special.cython_special', '__pyx_fuse_0pdtr')),
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/extending.py", line 406, in get_cython_function_address
    return _import_cython_function(module_name, function_name)
ValueError: No function '__pyx_fuse_0pdtr' found in __pyx_capi__ of 'scipy.special.cython_special'
...