Синоним Numba для добавления значений с использованием логической индексации в numpy - PullRequest
0 голосов
/ 01 мая 2020

Я пытаюсь создать более эффективный код, но застрял при реализации следующей версии Numba:

import numpy as np

a = np.array([[0, 0, 0, 0],
              [0, 0, 0, 0]])

bool_idx = np.array([True, False, False, True])

a[0, bool_idx] += 3
a

array([[3, 0, 0, 3],
       [0, 0, 0, 0]])

К сожалению, я получаю ошибку при переносе этого кода в функцию с помощью numba:

@njit
def add_to_arr(a, idx, arr_bool, add):
    arr[idx, arr_bool] += 3
    return arr

add_to_arr(a=a, idx=0, arr_bool=bool_idx, add=3)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(int32, 2d, C), (int64, array(bool, 1d, C)))

1 Ответ

0 голосов
/ 01 мая 2020

Кажется, что в этом случае Numba допускает только расширенную индексацию для первого измерения массива. Мы можем переписать функцию (также исправляя небольшую опечатку), чтобы приспособиться к этому, просто используя транспонирование и обращая индекс:

@njit 
def add_to_arr(a, idx, arr_bool, add): 
    a.T[arr_bool, idx] += 3 
    return a 

add_to_arr(a, 0, bool_idx, 3)     

Это работает для меня, давая:

array([[3, 0, 0, 3],
       [0, 0, 0, 0]])

Документация говорит, что расширенная индексация разрешена только в одном измерении, но не указывает, что это должно быть первое измерение, поэтому это может быть ошибкой.

...