Подсчитайте количество ненулевых значений в массиве Numpa - PullRequest
0 голосов
/ 22 февраля 2019

Очень просто.Я пытаюсь подсчитать количество ненулевых значений в массиве в NumPy jit, скомпилированном с помощью Numba (njit()).Следующее, что я пробовал, не разрешено Numba.

  1. a[a != 0].size
  2. np.count_nonzero(a)
  3. len(a[a != 0])
  4. len(a) - len(a[a == 0])

Я не хочу использовать для циклов, если есть еще более быстрый, более питонский и элегантный способ.

Для того комментатора, который хотел увидеть пример полного кода...

import numpy as np
from numba import njit

@njit()
def n_nonzero(a):
    return a[a != 0].size

Ответы [ 4 ]

0 голосов
/ 22 февраля 2019

В случае, если вам нужно действительно быстро для больших массивов, вы можете даже использовать numbas prange для параллельной обработки счета (для небольших массивов это будет медленнее из-за издержек параллельной обработки).

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_

Обратите внимание, что когда вы используете numba, вы обычно хотите записывать свои циклы, потому что именно это numba действительно очень хорошо умеет оптимизировать.

Я фактически сравнил это с другими решениями, упомянутыми здесь.(используя мой модуль Python simple_benchmark):

enter image description here

Код для воспроизведения:

import numpy as np
from numba import njit, prange

@njit
def n_nonzero(a):
    return a[a != 0].size

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

@njit() 
def methodB(a): 
    return (a!=0).sum()

@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_

@njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

from simple_benchmark import benchmark

args = {}
for exp in range(2, 20):
    size = 2**exp
    arr = np.random.random(size)
    arr[arr < 0.3] = 0.0
    args[size] = arr

b = benchmark(
    funcs=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop),
    arguments=args,
    argument_name='array size',
    warmups=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop)
)
0 голосов
/ 22 февраля 2019

Не уверен, что здесь я допустил ошибку, но это кажется в 6 раз быстрее:

# Make something worth checking
a=np.random.randint(0,3,1000000000,dtype=np.uint8)  

In [41]: @njit() 
    ...: def methodA(a): 
    ...:     return len(np.nonzero(a)[0])                                                                                           

# Call and check result
In [42]: methodA(a)                                                                                 
Out[42]: 666644445

In [43]: %timeit methodA(a)                                                                         
4.65 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [44]: @njit() 
    ...: def methodB(a): 
    ...:     return (a!=0).sum()                                                                                         

# Call and check result    
In [45]: methodB(a)                                                                                 
Out[45]: 666644445

In [46]: %timeit methodB(a)                                                                         
724 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
0 голосов
/ 22 февраля 2019

Вы также можете рассмотреть, ну, подсчет ненулевых значений:

import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

Я знаю, что это кажется неправильным, но потерпите меня:

import numpy as np
import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

@nb.njit()
def count_len_nonzero(a):
    return len(np.nonzero(a)[0])

@nb.njit()
def count_sum_neq_zero(a):
    return (a != 0).sum()

np.random.seed(100)
a = np.random.randint(0, 3, 1000000000, dtype=np.uint8)
c = np.count_nonzero(a)
assert count_len_nonzero(a) == c
assert count_sum_neq_zero(a) == c
assert count_loop(a) == c

%timeit count_len_nonzero(a)
# 5.94 s ± 141 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_sum_neq_zero(a)
# 848 ms ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_loop(a)
# 189 ms ± 4.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Это на самом деле быстрее, чемnp.count_nonzero, который может быть довольно медленным по некоторым причинам:

%timeit np.count_nonzero(a)
# 4.36 s ± 69.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
0 голосов
/ 22 февраля 2019

Вы можете использовать np.nonzero и указать длину:

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

count_non_zero(np.array([0,1,0,1]))
# 2
...