Обходной путь для аргумента оси numpy np.all; совместимость с Numba - PullRequest
0 голосов
/ 19 апреля 2020

У меня есть функция, которая, учитывая массив numpy координат xy, фильтрует те из них, которые лежат в поле со стороной L

import numpy as np
from numba import njit

np.random.seed(65238758)

L = 10
N = 1000
xy = np.random.uniform(0, 50, (N, 2))
box = np.array([
    [0,0],  # lower-left
    [L,L]  # upper-right
]) 

def sinjit(xy, box):
    mask = np.all(np.logical_and(xy >= box[0], xy <= box[1]), axis=1)
    return xy[mask]

Если я запускаю эту функцию, она возвращает правильный результат:

sinjit(xy, box)

Output: array([[5.53200522, 7.86890708],
       [4.60188554, 9.15249881],
       [9.072563  , 5.6874726 ],
       [4.48976127, 8.73258166],
       ...
       [6.29683131, 5.34225758],
       [2.68057087, 5.09835442],
       [5.98608603, 4.87845464],
       [2.42049857, 6.34739079],
       [4.28586677, 5.79125413]])

Но, поскольку я хочу ускорить эту задачу в al oop с помощью numba, существует проблема совместимости с аргументом "axis" в функции np.all (она не реализована в no python режим). Итак, мой вопрос: можно ли каким-либо образом избежать такого аргумента? любой обходной путь?

...