Выполнить групповую операцию на двумерном массиве - PullRequest
1 голос
/ 27 июня 2019

У меня есть двумерный массив (фактически матрица сходства), по которому мне нужно вычислить среднее по блокам. Например, со следующей матрицей:

sima = np.array([[1,0.8,0.7,0.3,0.1,0.5],
                 [0.8,1,0.1,0.5,0.2,0.5],
                 [0.7,0.1,1,0.1,0.3,0.9],
                 [0.3,0.5,0.1,1,0.8,0.5],
                 [0.1,0.2,0.3,0.8,1,0.5],
                 [0.5,0.5,0.9,0.5,0.5,1]])

И метки вектора:

labels = np.array([1,1,1,2,2,3])

Это означает, что первые три строки матрицы (а также столбцы столбцов, поскольку матрица подобия симметрична) соответствуют кластеру 1, следующие 2 соответствуют кластеру 2, а последние соответствуют кластер 3.

Мне нужно вычислить среднее количество блоков в sima, соответствующее меткам в labels. Дает следующий вывод:

0.69 0.25 0.63 
0.25 0.90 0.50 
0.63 0.50 1.00

Пока у меня есть рабочее решение, использующее двойную петлю для меток и замаскированных массивов:

labels_matrix = np.tile(np.array(labels), (len(labels), 1))
output = pd.DataFrame(np.zeros(shape = (3,3)))

for i in range(3):
  for j in range(3):
    mask = (labels_matrix != j+1) | (labels_matrix.T != i+1)
    output.loc[i,j] = np.mean(np.mean(np.ma.array(sima, mask = mask)))

Этот код выдает правильный вывод, но моя фактическая матрица 50kx50k, и этот код требует вечных вычислений. Как я мог сделать это быстрее?

Примечание: мне нужен другой порядок величины скорости, поэтому я ожидаю, что использования трюков, таких как симметрия матрицы подобия, будет недостаточно.

1 Ответ

2 голосов
/ 27 июня 2019

Для сортированных этикеток мы можем использовать np.add.reduceat -

In [62]: idx = np.flatnonzero(np.r_[True,labels[:-1] != labels[1:],True])

In [63]: c = np.diff(idx)

In [64]: sums = np.add.reduceat(np.add.reduceat(sima,idx[:-1],axis=0),idx[:-1],axis=1)

In [65]: sums/(c[:,None]*c)
Out[65]: 
array([[0.68888889, 0.25      , 0.63333333],
       [0.25      , 0.9       , 0.5       ],
       [0.63333333, 0.5       , 1.        ]])
...