Ошибка в numba @njit при индексации массива numpy - PullRequest
0 голосов
/ 20 апреля 2020

Я пытаюсь создать с помощью numba функцию, которая возвращает массив numpy, оцененный по другому массиву. Я опубликую простой код без njit:

import numpy as np
import numba as nb

def prueba(arr, eva):
    mask = []
    for i in range(len(arr)):
        mask.append(arr[i])
    return eva[mask]

Он работает правильно, как и ожидалось :

>>> prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
array([6, 7, 8])

Тем не менее, когда я пытаюсь скомпилировать его с помощью numba в режиме python (@njit), выдается ошибка

@nb.njit
def prueba(arr, eva):
    mask = []
    for i in range(len(arr)):
        mask.append(arr[i])
    return eva[mask]

>>> prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-9-111474f08921> in <module>
----> 1 prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))

~/.local/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
    399                 e.patch_message(msg)
    400 
--> 401             error_rewrite(e, 'typing')
    402         except errors.UnsupportedError as e:
    403             # Something unsupported is present in the user code, add help info

~/.local/lib/python3.7/site-packages/numba/dispatcher.py in error_rewrite(e, issue_type)
    342                 raise e
    343             else:
--> 344                 reraise(type(e), e, None)
    345 
    346         argtypes = []

~/.local/lib/python3.7/site-packages/numba/six.py in reraise(tp, value, tb)
    666             value = tp()
    667         if value.__traceback__ is not tb:
--> 668             raise value.with_traceback(tb)
    669         raise value
    670 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(int64, 1d, C), list(int64))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    All templates rejected with literals.
In definition 9:
    All templates rejected without literals.
In definition 10:
    All templates rejected with literals.
In definition 11:
    All templates rejected without literals.
In definition 12:
    TypeError: unsupported array index type list(int64) in [list(int64)]
    raised from /home/donielix/.local/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
In definition 13:
    TypeError: unsupported array index type list(int64) in [list(int64)]
    raised from /home/donielix/.local/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at <ipython-input-8-1b5c9f1a65d5> (6)
[2] During: typing of static-get-item at <ipython-input-8-1b5c9f1a65d5> (6)

File "<ipython-input-8-1b5c9f1a65d5>", line 6:
def prueba(arr, eva):
    <source elided>
        mask.append(arr[i])
    return eva[mask]
    ^

Так что мой вопрос почему этот простой код дает неожиданную ошибку? И как мне обойти эту проблему?

Ответы [ 2 ]

1 голос
/ 20 апреля 2020

Непосредственно из документации:

Также поддерживается подмножество расширенной индексации: разрешен только один расширенный индекс, и он должен быть одномерным массивом (его можно комбинировать с произвольное количество базовых c индексов). https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#array -доступ

Поэтому, чтобы ваш код работал, вы должны преобразовать mask в numpy array:

@nb.njit
def prueba(arr, eva):
    mask = []
    for i in range(len(arr)):
        mask.append(arr[i])
    mask_as_array = np.array(mask)
    return eva[mask_as_array]

prueba(np.array([1,2,3]), np.array([5,6,7,8,9,10]))
0 голосов
/ 20 апреля 2020

Ваша индексация с использованием numpy:

In [181]: a, b = np.array([1,2,3]), np.array([5,6,7,8,9,10])                                           
In [182]: b[a]                                                                                         
Out[182]: array([6, 7, 8])
In [183]: def foo(arr, eva): 
     ...:     return eva[arr] 
     ...:                                                                                              
In [184]: foo(a,b)                                                                                     
Out[184]: array([6, 7, 8])
In [186]: timeit foo(a,b)                                                                              
350 ns ± 9.98 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

Попытка воспроизвести его (и, возможно, ускорить) с помощью numba:

In [185]: import numba                                                                                 

In [187]: @numba.njit 
     ...: def foo1(arr,eva): 
     ...:     return eva[arr] 
     ...:                                                                                              
In [188]: foo1(a,b)                                                                                    
Out[188]: array([6, 7, 8])
In [189]: timeit foo1(a,b)                                                                             
968 ns ± 19.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [190]: @numba.njit 
     ...: def foo2(arr,eva): 
     ...:     res = np.empty(len(arr), eva.dtype) 
     ...:     for i in range(len(arr)): 
     ...:         res[i] = b[a[i]] 
     ...:     return res 

In [191]: foo2(a,b)                                                                                    
Out[191]: array([6, 7, 8])
In [192]: timeit foo2(a,b)                                                                             
941 ns ± 7.91 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [193]: @numba.njit 
     ...: def foo2(arr,eva): 
     ...:     res = np.empty(len(arr), eva.dtype) 
     ...:     for i,v in enumerate(a): 
     ...:         res[i] = b[v] 
     ...:     return res 

In [194]: foo2(a,b)                                                                                    
Out[194]: array([6, 7, 8])
In [195]: timeit foo2(a,b)                                                                             
941 ns ± 17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

Не много смысла пытаться заменить базовую c numpy функциональность на numba.

Кто-то с большим опытом numba может улучшить это.

edit

Как я изначально заметил, numba не любит индексировать со списком. Преобразование списка в массив работает:

In [196]: @numba.njit 
     ...: def prueba(arr, eva): 
     ...:     mask = [] 
     ...:     for i in range(len(arr)): 
     ...:         mask.append(arr[i]) 
     ...:     mask = np.array(mask) 
     ...:     return eva[mask] 
     ...:                                                                                              
In [197]: prueba(a,b)                                                                                  
Out[197]: array([6, 7, 8])
In [198]: timeit prueba(a,b)                                                                           
1.5 µs ± 4.79 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
...