Numpy - извлечение уникальных значений из матрицы расстояний - PullRequest
0 голосов
/ 12 июня 2018

Входные данные:

У меня есть следующее distance_matrix:

  [[1.        , 0.14285714, 0.25      , 0.83333333, 0.63636364],
   [0.14285714, 1.        , 0.33333333, 0.84615385, 0.66666667],
   [0.25      , 0.33333333, 1.        , 0.76923077, 0.58333333],
   [0.83333333, 0.84615385, 0.76923077, 1.        , 0.69230769],
   [0.63636364, 0.66666667, 0.58333333, 0.69230769, 1.        ]]

Текущие результаты:

np.where(distane_matrix <= 0.25) возвращает следующий вывод:

(array([0, 0, 1, 2]), array([1, 2, 0, 0]))

Желаемые результаты:

(array([0, 0]), array([1, 2]))

Объяснение:

Чтобы выразить это словами, поскольку я знаю, что:

  1. [0,1] имеет то же значение, что и [1,0]
  2. [0,2] имеет то же значение, что и [2,0]
  3. [0,1] и [0,2] удовлетворяюттребования np.where()

Я не хочу, чтобы [1,0] и [2,0] были возвращены на выходе, так как это избыточная информация.Каков наилучший способ сделать это?

Ответы [ 4 ]

0 голосов
/ 13 июня 2018

вместо np.triu, просто используйте np.triu_indicesd в качестве матрицы расстояний:

def dist_thr(d, thr = .25):
    i, j = triu_indices(d.shape[0], 1)
    mask = d[i, j] < thr
    return i[mask], j[mask]
0 голосов
/ 13 июня 2018

Если вам нужно решение, не зависящее от критериев выбора, вы можете рассмотреть возможность использования маскированного массива:

import numpy.ma as ma

mat_masked = ma.array(your_mat, mask = np.triu(np.ones(np.shape(your_mat))))

Отсюда вы продолжите, как и раньше

np.where(mat_masked <= 0.25) 

Это имееттакже преимущество в том, что ваши данные остаются нетронутыми.

0 голосов
/ 13 июня 2018

Вот подход с masking -

def get_lower_indices(a, thresh=0.25):
    n = a.shape[0]    
    ra = np.arange(n)
    mask = ra[:,None] < ra
    v = np.flatnonzero(a[mask] <= thresh)
    idx = np.concatenate(( [0], np.arange(n-1,0,-1).cumsum() ))
    c = np.searchsorted(idx, v,'right')-1
    r = v-idx[c]+c+1    
    return c,r

Пробный прогон -

In [116]: a
Out[116]: 
array([[1.        , 0.14285714, 0.25      , 0.83333333, 0.63636364],
       [0.14285714, 1.        , 0.33333333, 0.84615385, 0.66666667],
       [0.25      , 0.33333333, 1.        , 0.76923077, 0.58333333],
       [0.83333333, 0.84615385, 0.76923077, 1.        , 0.69230769],
       [0.63636364, 0.66666667, 0.58333333, 0.69230769, 1.        ]])

In [117]: get_lower_indices(a, thresh=0.25)
Out[117]: (array([0, 0]), array([1, 2]))

Если вы согласны с редактированием верхних треугольных элементов на более высокое значение, которое выиграло 'не быть пойманным в пороговой операции, мы можем сделать что-то в этом направлении -

def get_lower_indices_mask_editing(a, thresh=0.25):
    n = a.shape[0]
    r = np.arange(n)
    a[r[:,None] >= r] = 1
    return np.where(a<=thresh)

Другие подходы:

# @MaxU's soln
def triu_where(d):
    r = np.triu(d, 1)
    return np.where((r>0) & (r<=0.25))

Время -

In [231]: # Setup random array with larger size and no zeros
     ...: np.random.seed(0)
     ...: N = 5000
     ...: data = np.random.rand(N,N)
     ...: data = data.dot(data.T)
     ...: data = (data - data.min())/(data.max() -data.min())
     ...: data[data==0] = 0.1
     ...: np.fill_diagonal(data,1)

# @MaxU's soln
In [232]: %timeit triu_where(data)
10 loops, best of 3: 174 ms per loop

In [233]: %timeit get_lower_indices(data, thresh=0.25)
1 loop, best of 3: 318 ms per loop

In [234]: %timeit get_lower_indices_mask_editing(data, thresh=0.25)
10 loops, best of 3: 150 ms per loop
0 голосов
/ 12 июня 2018

Предполагая, d - матрица указанного расстояния.

Демонстрация:

In [28]: r = np.triu(d, 1)

In [29]: r
Out[29]:
array([[0.        , 0.14285714, 0.25      , 0.83333333, 0.63636364],
       [0.        , 0.        , 0.33333333, 0.84615385, 0.66666667],
       [0.        , 0.        , 0.        , 0.76923077, 0.58333333],
       [0.        , 0.        , 0.        , 0.        , 0.69230769],
       [0.        , 0.        , 0.        , 0.        , 0.        ]])

In [30]: np.where((r>0) & (r<=0.25))
Out[30]: (array([0, 0], dtype=int64), array([1, 2], dtype=int64))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...