Пересечение двух прямоугольников с NumPy - PullRequest
0 голосов
/ 18 марта 2020

У меня есть следующая функция, чтобы найти пересечение двух прямоугольников. Это немного медленно, я не знаю, из-за условия OR или операторов >, <. Интересно, есть ли способ улучшить производительность функции is_intersect()? Может быть с NumPy? Или Cython?

import numpy as np

def is_intersect(rect1, rect2):
    xmin1, xmax1, ymin1, ymax1 = rect1
    xmin2, xmax2, ymin2, ymax2 = rect2
    if xmin1 > xmax2 or xmax1 < xmin2:
        return False
    if ymin1 > ymax2 or ymax1 < ymax2:
        return False
    return True

N_ELEMS = 100000000
rects1 = np.random.rand(N_ELEMS,4)
rects2 = np.random.rand(N_ELEMS,4)

temp_dct = dict()

for i in range(N_ELEMS):
    rect1 = rects1[i,:]
    rect2 = rects2[i,:]
    if is_intersect(rect1, rect2):
        temp_dct[i] = True

Я не могу извлечь выгоду из результатов кэширования, так как точки будут инкрементными, то есть один прямоугольник будет перемещаться в пространстве (никогда в одном и том же месте). В этом примере я использовал функцию NumPy random(), но это не так для моего реального использования. Я вызову функцию is_intersect() 100 000 000 раз или более.

1 Ответ

3 голосов
/ 18 марта 2020

Вы можете улучшить производительность, избегая for l oop, используя векторизованное сравнение и np.any:

result = (1 - np.any([rects1[:,0] > rects2[:,1], 
                      rects1[:,1] < rects2[:,0], 
                      rects1[:,2] > rects2[:,3], 
                      rects1[:,3] < rects2[:,2]], 
                     axis=0)).astype(bool)

У вас нет словаря, но вы можете получить доступ к result по индексу.

Производительность с элементами 100M:

import numpy as np
import timeit

N_ELEMS = 100_000_000
rects1 = np.random.rand(N_ELEMS,4)
rects2 = np.random.rand(N_ELEMS,4)

start_time = timeit.default_timer()
result = (1 - np.any([rects1[:,0] > rects2[:,1], 
                      rects1[:,1] < rects2[:,0], 
                      rects1[:,2] > rects2[:,3], 
                      rects1[:,3] < rects2[:,2]], 
                     axis=0)).astype(bool)

print(timeit.default_timer() - start_time)
2.9162093999999996
...