Есть ли способ быстрее, чем np.isin для большого массива? - PullRequest
2 голосов
/ 25 мая 2020

Для большого массива (n> 1e8) существует ли более быстрый способ, чем np.isin, для проверки наличия одинаковых элементов?

Я пробовал несколько методов, например pandas isin, cython, но все они занимают больше времени, чем пример np.isin

: (Проверьте, каждый ли элемент одномерного массива также присутствует во втором массиве)

num = int(1e8)
a = np.random.rand(int(num))
b = np.random.rand(int(num))

ref=time.time()
ainb = np.isin(a,b)
print(a[ainb])
print(time.time()-ref,'sec')

>>> [0.23591019 0.46102523]
>>> 65.45570135116577 sec

Ответы [ 2 ]

3 голосов
/ 25 мая 2020

Если вам нужен добавочный (для вашего варианта использования), но, возможно, более быстрая замена np.isin(), вы можете использовать Python set() для проверки и ускорения явного цикла в Numba:

import numpy as np
import numba as nb


@nb.jit
def is_in_set_nb(a, b):
    shape = a.shape
    a = a.ravel()
    n = len(a)
    result = np.full(n, False)
    set_b = set(b)
    for i in range(n):
        if a[i] in set_b:
            result[i] = True
    return result.reshape(shape)

Обратите внимание, что есть некоторый (дешевый) дополнительный код, чтобы заставить его работать для массивов N-dim, который вы, вероятно, могли бы пропустить, если вам нужен только 1D.

Это можно было бы даже ускорить, добавив дополнительные распараллеливание:

import numpy as np
import numba as nb


@nb.jit(parallel=True)
def is_in_set_pnb(a, b):
    shape = a.shape
    a = a.ravel()
    n = len(a)
    result = np.full(n, False)
    set_b = set(b)
    for i in nb.prange(n):
        if a[i] in set_b:
            result[i] = True
    return result.reshape(shape)

Это происходит намного быстрее, чем np.isin(), set() пересечение и решение is_in_set() без ускорения Numba:

def is_in_set(a, b):
    set_b = set(b)
    return np.array([x in set_b for x in a])

с размером ввода десять миллионов элементов:

n = 10 ** 7
k = n // 3
np.random.seed(0)
# note: I used `int`s because I wanted to be able to control the collisions
a = np.random.randint(0, k * n, n)
b = np.random.randint(0, k * n, n)


%timeit ainb = np.isin(a, b); a[ainb]
# 1 loop, best of 3: 3.94 s per loop
%timeit ainb = is_in_set_nb(a, b); a[ainb]
# 1 loop, best of 3: 814 ms per loop
%timeit ainb = is_in_set_pnb(a, b); a[ainb]
# 1 loop, best of 3: 740 ms per loop
%timeit ainb = is_in_set(a, b); a[ainb]
# 1 loop, best of 3: 7.69 s per loop
%timeit set(a).intersection(b)  # not a drop-in replacement
# 1 loop, best of 3: 6.79 s per loop
%timeit set(a) & set(b)  # not a drop-in replacement
# 1 loop, best of 3: 8.98 s per loop

и с сотнями миллионов элементов (последние два подхода закончились тем, что заполнили всю память и поэтому опущены):

n = 10 ** 8
k = n // 3
np.random.seed(0)
a = np.random.randint(0, k * n, n)
b = np.random.randint(0, k * n, n)


%timeit ainb = np.isin(a, b); a[ainb]
# 1 loop, best of 3: 1min 4s per loop
%timeit ainb = is_in_set_nb(a, b); a[ainb]
# 1 loop, best of 3: 13.1 s per loop
%timeit ainb = is_in_set_pnb(a, b); a[ainb]
# 1 loop, best of 3: 11.4 s per loop
%timeit ainb = is_in_set(a, b); a[ainb]
# 1 loop, best of 3: 2min 5s per loop

Добавление большего времени для меньших входов, но все комбинации длин a и b:

funcs = np.isin, is_in_set_nb, is_in_set_pnb
sep = '    '
print(f'({"n=len(a)":>9s},{"m=len(b)":>9s})', end=sep)
for func in funcs:
    print(f'{func.__name__:15s}', end=sep)
print()
I, J = 7, 7
for i in range(I):
    for j in range(J):
        n = 10 ** i
        m = 10 ** j
        a = np.random.randint(0, m * n, n)
        b = np.random.randint(0, m * n, m)
        print(f'({n:9d},{m:9d})', end=sep)
        for func in funcs:
            result = %timeit -q -o func(a, b)
            print(f'{result.best * 1e3:12.3f} ms', end=sep)
        print()
( n=len(a), m=len(b))    isin               is_in_set_nb       is_in_set_pnb      
(        1,        1)           0.011 ms           0.001 ms           0.047 ms    
(        1,       10)           0.048 ms           0.001 ms           0.023 ms    
(        1,      100)           0.050 ms           0.002 ms           0.027 ms    
(        1,     1000)           0.102 ms           0.007 ms           0.041 ms    
(        1,    10000)           0.766 ms           1.028 ms           1.122 ms    
(        1,   100000)           9.717 ms           3.426 ms           3.356 ms    
(        1,  1000000)         105.154 ms          43.642 ms          40.734 ms    
(       10,        1)           0.010 ms           0.001 ms           0.023 ms    
(       10,       10)           0.030 ms           0.001 ms           0.023 ms    
(       10,      100)           0.053 ms           0.002 ms           0.027 ms    
(       10,     1000)           0.100 ms           0.007 ms           0.055 ms    
(       10,    10000)           0.961 ms           1.031 ms           1.154 ms    
(       10,   100000)           9.772 ms           3.595 ms           3.761 ms    
(       10,  1000000)         105.802 ms          54.260 ms          50.265 ms    
(      100,        1)           0.010 ms           0.001 ms           0.024 ms    
(      100,       10)           0.030 ms           0.002 ms           0.025 ms    
(      100,      100)           0.054 ms           0.002 ms           0.026 ms    
(      100,     1000)           0.105 ms           0.008 ms           0.045 ms    
(      100,    10000)           0.751 ms           1.076 ms           1.158 ms    
(      100,   100000)           9.824 ms           3.253 ms           3.329 ms    
(      100,  1000000)         105.697 ms          57.993 ms          55.285 ms    
(     1000,        1)           0.012 ms           0.005 ms           0.028 ms    
(     1000,       10)           0.038 ms           0.006 ms           0.029 ms    
(     1000,      100)           0.119 ms           0.007 ms           0.033 ms    
(     1000,     1000)           0.180 ms           0.014 ms           0.063 ms    
(     1000,    10000)           0.821 ms           1.074 ms           1.169 ms    
(     1000,   100000)           9.920 ms           3.392 ms           3.532 ms    
(     1000,  1000000)         104.666 ms          57.845 ms          54.603 ms    
(    10000,        1)           0.020 ms           0.041 ms           0.092 ms    
(    10000,       10)           0.089 ms           0.088 ms           0.158 ms    
(    10000,      100)           0.967 ms           0.112 ms           0.182 ms    
(    10000,     1000)           1.017 ms           0.161 ms           0.249 ms    
(    10000,    10000)           1.633 ms           1.137 ms           1.283 ms    
(    10000,   100000)          10.754 ms           3.027 ms           3.302 ms    
(    10000,  1000000)         101.926 ms          48.062 ms          49.117 ms    
(   100000,        1)           0.071 ms           0.409 ms           0.455 ms    
(   100000,       10)           0.575 ms           0.916 ms           0.803 ms    
(   100000,      100)          16.304 ms           1.201 ms           0.940 ms    
(   100000,     1000)          15.185 ms           1.566 ms           1.181 ms    
(   100000,    10000)          15.914 ms           1.454 ms           1.252 ms    
(   100000,   100000)          23.719 ms           4.820 ms           4.313 ms    
(   100000,  1000000)         119.668 ms          56.863 ms          54.570 ms    
(  1000000,        1)           0.774 ms           4.347 ms           3.407 ms    
(  1000000,       10)           6.207 ms           8.793 ms           5.957 ms    
(  1000000,      100)         178.498 ms          13.104 ms           8.544 ms    
(  1000000,     1000)         169.022 ms          16.198 ms          10.283 ms    
(  1000000,    10000)         177.986 ms          13.243 ms           8.973 ms    
(  1000000,   100000)         177.989 ms          19.856 ms          13.898 ms    
(  1000000,  1000000)         283.207 ms          97.118 ms          84.332 ms 

Это показывает, что Numba и распараллеливание весьма полезны для больших входных данных и становятся немного менее эффективными t для меньших входов. Тем не менее, они по-прежнему превосходят np.isin() в большинстве из приведенных выше тестов.

0 голосов
/ 25 мая 2020

Я не уверен, подходит ли это вам, но пересечение множеств (&) выглядит быстрее. Обратной стороной является то, что вы получите только элементы, а не их индексы.

import numpy as np
import time

num = int(1e7)
a = np.random.rand(int(num))
b = np.random.rand(int(num))

a_set = set(a)
b_set = set(b)

Пересечение наборов времени:

ref=time.time()
print(a_set & b_set)
print(time.time()-ref,'sec')

>> set()
0.6486170291900635 sec

Время np.isin():

ref=time.time()
ainb = np.isin(a, b)
print( a[ainb] )
print(time.time()-ref,'sec')

>> []
4.06556510925293 sec
...