Оптимизировать эту функцию - numpy проблема вещания - PullRequest
1 голос
/ 06 января 2020

У меня есть функция contains, которая проверяет заданный 2D массив u, если в поле [min, max] содержится каждая строка u. Мне нужно изменить его u, если необходимо, но число значений u всегда будет кратно d (может быть ноль);

Я использую следующий фрагмент кода , Эта функция запускается тысячи раз. Может ли быть создан более быстрый код? Если вы так думаете, какие-либо советы о том, как?

import numpy as np

def contains(u, min, max, dim, strict = True):
    u = np.array(u).reshape(-1 ,dim)
    if strict:
        return np.all((u > min) & (u < max), axis=1)
    else:
        return np.all((u >= min) & (u <= max), axis=1)

# Usage examples : 
d = 4
min = np.random.uniform(size=d)*1/2
max = np.random.uniform(size=d)*1/2+1/2
u1 = np.random.uniform(size=d)
u2 = np.random.uniform(size=(100,d))
u3 = u2[np.repeat(False,100)]

contains(u1,min,max,d) # should return a boolean array of shape (1,)
contains(u2,min,max,d) # shape (100,)
contains(u3,min,max,d) # shape (0,)

Ответы [ 2 ]

3 голосов
/ 07 января 2020

( EDITED : исправить проблему измерения времени, поднятую @ max9111 в комментариях, и включить модифицированное numexpr решение).

Узкое место в конечном итоге может оказаться в пределах np.all() вызов. Это может быть ускорено с помощью Numba следующим образом:

import numpy as np
import numba as nb


@nb.jit(nopython=True)
def contains_nb(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    n = arr.shape[0]
    result = np.ones(n, dtype=np.bool8)
    for i in range(n):       
        for j in range(m):
            if not a_arr[j] < arr[i, j] < b_arr[j]:
                result[i] = False
                break
    return result

Это сравнивается с решением NumPy:

import numpy as np


def contains_np(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    return np.all((arr >= a_arr) & (arr <= b_arr), axis=1)

, которое я немного упростил по сравнению с вашим подходом (у меня есть пропущенные параметры dim и strict, поскольку dim является избыточным, так как он может быть выведен из размеров a_arr или b_arr, в то время как параметр strict не добавляет много к решению, но может быть легко вводится). Я также предполагаю, что входные данные уже всегда являются массивом NumPy.

Кроме того, решение NumPy можно изменить, чтобы использовать numexpr, что приводит к третьему подходу. Это вызовет некоторые накладные расходы, но может ускорить вычисления, например:

import numpy as np
import numexpr as ne


def contains_ne(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    result = ne.evaluate('(arr >= a_arr) & (arr <= b_arr)')
    return np.all(result, axis=1)

Можно получить следующие тесты:

bm

Это показывает, что решение Numba неизменно является самым быстрым. Напротив, использование numexpr представляется бесполезным для диапазона исследуемых параметров.

(доступно полное тестирование здесь )

2 голосов
/ 06 января 2020

Попробуйте, чтобы ускорить, подробнее здесь

from numba import jit

@jit(nopython=True)
def contains(u, min, max, dim, strict = True):
    u = np.array(u).reshape(-1 ,dim)
    if strict:
        return np.all((u > min) & (u < max), axis=1)
    else:
        return np.all((u >= min) & (u <= max), axis=1)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...