Расширенная альтернатива индексирования для функции njit numba - PullRequest
2 голосов
/ 10 марта 2020

Учитывая следующий минимальный воспроизводимый пример:

import numpy as np
from numba import jit

# variable number of dimensions
n_t = 8
# q is just a partition of n
q_ddl = 2
n_ddl = 3

np.random.seed(42)
df = np.random.rand(q_ddl*n_t,q_ddl*n_t)

# index array
# ddl_nl is a set of np.arange(n_ddl), ex: [0,1] ; [0,2] or even [0] ...
ddl_nl = np.array([0,1])
ij = np.asarray(np.meshgrid(ddl_nl,ddl_nl,indexing='ij'))

@jit(nopython=True)
def foo(df,ij):
    out = np.zeros((n_t,n_ddl,n_ddl))
    for i in range(0,n_t):     
        d_i = np.zeros((n_ddl,n_ddl))
        # (q_ddl,q_ddl) non zero values into (n_ddl,n_ddl) shape
        d_i[ij[0], ij[1]] = df[i::n_t,i::n_t]
        # to check possible solutions
        out[i,...] = d_i
    return out


out_foo = foo(df,ij)

Функция foo работает хорошо, когда @jit(nopython=True) отключена, но выдает следующую ошибку, когда включена:

TypeError: unsupported array index type array(int64, 2d, C) in UniTuple(array(int64, 2d, C) x 2)

, что произошло во время операции вещания d_i[ij[0], ij[1]] = df[i::n_t,i::n_t]. Затем я попытался сгладить двумерные индексные массивы ij чем-то вроде d_i[ij[0].ravel(), ij[1].ravel()] = df[i::n_t,i::n_t].ravel(), что дает мне тот же вывод, но теперь еще одну ошибку:

NotImplementedError: only one advanced index supported

Так что я наконец попытался избежать этого, используя классическая структура с 2 вложенными for петлями:

tmp = df[i::n_t,i::n_t]
for k,r in enumerate(ddl_nl):
    for l,c in enumerate(ddl_nl):
        d_i[r,c] = tmp[k,l]

, которая работает с включенным декоратором и дает ожидаемый результат. совместимые альтернативы для этой операции вещания numpy 2d-массива, которые мне здесь не хватает? Любая помощь будет принята с благодарностью.

Ответы [ 2 ]

1 голос
/ 10 марта 2020

Избегайте необычного индексирования

Также избегайте использования глобальных переменных (они жестко запрограммированы во время компиляции) и сохраняйте свой код как можно более простым (просто означает только петли росы, если / еще, ...). Если массив ddl_nl действительно создается только с использованием np.arange, даже этот массив вообще не нужен.

Пример

import numpy as np
from numba import jit

@jit(nopython=True)
def foo_nb(df,n_ddl,n_t,ddl_nl):
    out = np.zeros((n_t,n_ddl,n_ddl))
    for i in range(0,n_t):
        for ii in range(ddl_nl.shape[0]):
            ind_1=ddl_nl[ii]
            for jj in range(ddl_nl.shape[0]):
                ind_2=ddl_nl[jj]
                out[i,ind_1,ind_2] = df[i+ii*n_t,i+jj*n_t]
    return out

Сроки

#Testing and compilation
A=foo(df,ij)
B=foo_nb(df,n_ddl,n_t,ddl_nl)
print(np.allclose(A,B))
#True
%timeit foo(df,ij)
#16.8 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit foo_nb(df,n_ddl,n_t,ddl_nl)
#674 ns ± 2.56 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
1 голос
/ 10 марта 2020

Проверка некоторых ваших значений:

In [446]: ddl_nl = np.array([0,1]) 
     ...: ij = np.asarray(np.meshgrid(ddl_nl,ddl_nl,indexing='ij')) 
     ...:                                                                                      
In [447]: ij                                                                                   
Out[447]: 
array([[[0, 0],
        [1, 1]],

       [[0, 1],
        [0, 1]]])
In [448]: n_t = 8 
     ...: q_ddl = 2 
     ...: n_ddl = 3                                                                            
In [449]: d_i = np.zeros((n_ddl,n_ddl))                                                        
In [450]: d_i                                                                                  
Out[450]: 
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])
In [451]: d_i[ij[0], ij[1]]                                                                    
Out[451]: 
array([[0., 0.],
       [0., 0.]])

попробуйте больше диагнозов c d_i:

In [452]: d_i = np.arange(9).reshape(3,3)                                                      
In [453]: d_i[ij[0], ij[1]]                                                                    
Out[453]: 
array([[0, 1],
       [3, 4]])
In [454]: d_i[:2,:2]                                                                           
Out[454]: 
array([[0, 1],
       [3, 4]])

Почему вы используете расширенную индексацию, когда базовая c нарезка будет работать?

Я не пробовал это с numba, но у него больше шансов на работу. Тем не менее, перечисленное l oop может быть столь же быстрым. У меня недостаточно опыта работы с numba, чтобы сказать наверняка.

===

Очевидно, вы выполнили операцию numpy, которую numba не поддерживает:

In [456]: numba.__version__                                                                    
Out[456]: '0.43.0'
In [457]: @numba.jit 
     ...: def foo(arr): 
     ...:     return arr[[1,2,3],[1,2,3]] 
     ...:                                                                                      
In [458]: foo(np.eye(4))                                                                       
Out[458]: array([1., 1., 1.])
In [459]: @numba.njit 
     ...: def foo(arr): 
     ...:     return arr[[1,2,3],[1,2,3]] 
     ...:                                                                                      
In [460]: foo(np.eye(4))    
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, C), tuple(list(int64) x 2))

Это не необычно. numba не претендует на полное покрытие Python или numpy.

Но с numba нам не нужно избегать итерации. На самом деле это лучше всего при замене операции, которую numpy не может обойтись без итерации.

In [465]: @numba.njit 
     ...: def foo(arr): 
     ...:     out = np.zeros((3,), arr.dtype) 
     ...:     for n, (i,j) in enumerate(zip([1,2,3],[1,2,3])): 
     ...:         out[n] = arr[i,j] 
     ...:     return out 

In [466]: foo(np.eye(4))                                                                       
Out[466]: array([1., 1., 1.])
In [467]: timeit foo(np.eye(4))                                                                
6.85 µs ± 28.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [468]: np.eye(4)[[1,2,3],[1,2,3]]                                                           
Out[468]: array([1., 1., 1.])
In [469]: timeit np.eye(4)[[1,2,3],[1,2,3]]                                                    
13.3 µs ± 31.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
...