Пересечение двух списков в Numba - PullRequest
1 голос
/ 29 января 2020

Я хотел бы знать самый быстрый способ вычисления пересечения двух списков в функции numba. Просто для пояснения: пример пересечения двух списков:

Input : 
lst1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
lst2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]
Output :
[9, 10, 4, 5]

Проблема в том, что это нужно вычислять в функции numba, и поэтому, например, наборы использовать нельзя. У тебя есть идея? Мой текущий код очень основа c. Я предполагаю, что есть возможности для улучшения.

@nb.njit
def intersection:
   result = []
   for element1 in lst1:
      for element2 in lst2:
         if element1 == element2:
            result.append(element1)
   ....

Ответы [ 3 ]

2 голосов
/ 29 января 2020

Поскольку numba компилирует и запускает ваш код в машинном коде, вы, вероятно, в лучшем случае для такой простой операции. Я провел несколько тестов ниже

@nb.njit
def loop_intersection(lst1, lst2):
    result = []
    for element1 in lst1:
        for element2 in lst2:
            if element1 == element2:
                result.append(element1)
    return result

@nb.njit
def set_intersect(lst1, lst2):
    return set(lst1).intersection(set(lst2))

Resuls

loop_intersection
40.4 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

set_intersect
42 µs ± 6.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1 голос
/ 29 января 2020

Я немного поиграл с этим, чтобы попытаться чему-то научиться, понимая, что ответ уже дан. Когда я запускаю принятый ответ, я получаю возвращаемое значение [9, 10, 5, 4, 9]. Я не был уверен, было ли повторное 9 приемлемым или нет. Предполагая, что все в порядке, я запустил пробную версию, используя понимание списка, чтобы увидеть, что это имеет какое-то значение. Мои результаты:

from numba import jit

def createLists():
    l1 = [15, 9, 10, 56, 23, 78, 5, 4, 9]
    l2 = [9, 4, 5, 36, 47, 26, 10, 45, 87]

@jit
def listComp():
    l1, l2 = createLists()
    return [i for i in l1 for j in l2 if i == j]

% timeit listComp () 5,84 микросекунды +/- 10,5 наносекунд

Или, если вы можете использовать Numpy, этот код еще быстрее и удаляет дубликаты "9 "и намного быстрее с подписью Numba.

import numpy as np
from numba import jit, int64

@jit(int64[:](int64[:], int64[:]))
def JitListComp(l1, l2):
    l3 = np.array([i for i in l1 for j in l2 if i == j])
    return np.unique(l3) # and i not in crossSec]

@jit
def CreateList():
    l1 = np.array([15, 9, 10, 56, 23, 78, 5, 4, 9])
    l2 = np.array([9, 4, 5, 36, 47, 26, 10, 45, 87])
    return JitListComp(l1, l2)

CreateList()
Out[39]: array([ 4,  5,  9, 10])

%timeit CreateList()
1.71 µs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
1 голос
/ 29 января 2020

Для этого можно использовать заданную операцию:

def intersection(lst1, lst2): 
    return list(set(lst1) & set(lst2))

, затем просто вызвать функцию intersection(lst1,lst2). Это будет самый простой способ.

...