Как эффективно сравнить каждую пару строк в 2D матрице? - PullRequest
1 голос
/ 01 мая 2019

Я работаю над подпрограммой, где мне нужно обработать каждую строку матрицы и найти, какие другие строки содержат в текущей строке. Для иллюстрации того, когда строка содержит другую, рассмотрим матрицу 3x3, как показано ниже:

[[1, 0, 1], 

 [0, 1, 0], 

 [1, 0, 0]]

Здесь строка 1 содержит строку 3, поскольку каждый элемент в строке 1 больше или равен строке 3 , но строка 1 не содержит строки 2.

Я пришел к следующему решению, но оно очень медленное из-за цикла for (размер матрицы около 6000x6000).

for i in range(no_of_rows):
    # Here Adj is the 2D matrix 
    contains = np.argwhere(np.all(Adj[i] >= Adj, axis = 1))

Не могли бы вы сообщить мне, если это возможно сделать более эффективно?

Ответы [ 2 ]

1 голос
/ 01 мая 2019

Из-за размера ваших матриц и требований вашей проблемы, я думаю, итерация неизбежна.Вы не можете использовать широковещательную рассылку, так как она взорвет вашу память, поэтому вам нужно работать с существующим массивом построчно.Однако вы можете использовать numba и njit, чтобы значительно ускорить этот процесс по сравнению с подходом чистого Python.


import numpy as np
from numba import njit


@njit
def zero_out_contained_rows(a):
    """
    Finds rows where all of the elements are
    equal or smaller than all corresponding
    elements of anothe row, and sets all
    values in the row to zero

    Parameters
    ----------
    a: ndarray
      The array to modify

    Returns
    -------
    The modified array

    Examples
    --------
    >>> zero_out_contained_rows(np.array([[1, 0, 1], [0, 1, 0], [1, 0, 0]]))
    array([[1, 0, 1],
            [0, 1, 0],
            [0, 0, 0]])
    """
    x, y = a.shape

    contained = np.zeros(x, dtype=np.bool_)

    for i in range(x):
        for j in range(x):
            if i != j and not contained[j]:
                equal = True
                for k in range(y):
                    if a[i, k] < a[j, k]:
                        equal = False
                        break
                contained[j] = equal

    a[contained] = 0

    return a

Это позволяет отслеживать, используется ли строка в другомстрока.Это предотвращает множество ненужных сравнений путем короткого замыкания, прежде чем окончательно уничтожить строки, содержащиеся в других, с помощью 0.


По сравнению с вашей первоначальной попыткой использования итерации, это также улучшение скоростиas также обрабатывает обнуление правильных строк.


a = np.random.randint(0, 2, (6000, 6000))

%timeit zero_out_contained_rows(a)
1.19 s ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Я обновлю время, как только ваша попытка завершится (в настоящее время ~ 10 минут).

0 голосов
/ 01 мая 2019

Если у вас матрица 6000x6000, чем вам нужно (6000 * 6000 - 6000) / 2 = 17997000 вычислений.

Вместо использования np.triu_indices, вы можете попробовать использовать генератор для верхнего треугольника вашегоматрица - это должно уменьшить потребление памяти.Попробуйте это, может быть, это поможет ..

def indxs(lst):
   for i1, el1 in enumerate(lst):
      for el2 in lst[i1:][1:]:
         yield (el1, el2)
...