Итерация через массивный массив, эффективно тестирующий несколько элементов - PullRequest
0 голосов
/ 30 января 2019

У меня есть следующий код, который перебирает 2d массив с именем "m".Работает крайне медленно.Как я могу преобразовать этот код, используя функции numpy, чтобы избежать использования циклов for?

pairs = []
for i in range(size):
    for j in range(size):
        if(i >= j):
            continue
        if(m[i][j] + m[j][i] >= 0.75):
            pairs.append([i, j, m[i][j] + m[j][i]])

Ответы [ 2 ]

0 голосов
/ 30 января 2019

Вы можете использовать векторизованный подход, используя NumPy.Идея такова:

  • Сначала инициализируйте матрицу m, а затем создайте m+m.T, что эквивалентно m[i][j] + m[j][i], где m.T - транспонирование матрицы и назовите ее summ
  • np.triu(summ) возвращает верхнюю треугольную часть матрицы (это эквивалентно игнорированию нижней части с использованием continue в вашем коде).Это позволяет избежать явного if(i >= j): в вашем коде.Здесь вы должны использовать k=1, чтобы исключить диагональные элементы.По умолчанию k=0, который включает в себя и диагональные элементы.
  • Затем вы получаете индексы точек, используя np.argwhere, где сумма m+m.T больше чем равна 0,75
  • Затем вы сохраняете эти индексы и соответствующие значенияв списке для последующей обработки / печати. ​​

Проверяемый пример (с использованием небольшого набора случайных данных 3x3)

import numpy as np

np.random.seed(0)
m = np.random.rand(3,3)
summ = m + m.T

index = np.argwhere(np.triu(summ, k=1)>=0.75)

pairs = [(x,y, summ[x,y]) for x,y in index]
print (pairs)
# # [(0, 1, 1.2600725493693163), (0, 2, 1.0403505873343364), (1, 2, 1.537667113848736)]

Дальнейшее улучшение производительности

Я только что разработал еще более быстрый подход для генерации окончательного списка pairs, избегая явных циклов для

pairs = list(zip(index[:, 0], index[:, 1], summ[index[:,0], index[:,1]]))
0 голосов
/ 30 января 2019

Один из способов оптимизировать ваш код - избегать сравнения if (i >= j).Чтобы обойти только нижний треугольник массива без этого сравнения, вы должны запустить внутренний цикл со значения i самого внешнего цикла.Таким образом, вы избегаете size x size if сравнений.

import numpy as np
size = 5000
m = np.random.rand(size, size)
pairs = []


for i in range(size):
    for j in range(i , size):

        if(m[i][j] + m[j][i] >= 0.75):
            pairs.append([i, j, m[i][j] + m[j][i]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...