Pythonical получает подмножество строк из матрицы numpy на основе условия в каждой строке и во всех столбцах - PullRequest
0 голосов
/ 23 февраля 2020

Учитывая следующую матрицу,

In [0]: a = np.array([[1,2,9,4,2,5],[4,5,1,4,2,4],[2,3,6,7,8,9],[5,6,7,4,3,6]])
Out[0]: 
array([[1, 2, 9, 4, 2, 5],
       [4, 5, 1, 4, 2, 4],
       [2, 3, 6, 7, 8, 9],
       [5, 6, 7, 4, 3, 6]])

Я хочу получить индексы строк, которые имеют 9 в качестве члена. Это,

idx = [0,2]

В настоящее время я делаю это,

def myf(x):
    if any(x==9):
        return True
    else:
        return False

aux = np.apply_along_axis(myf, axis=1, arr=a)
idx = np.where(aux)[0]

И я получаю желаемый результат.

In [1]: idx
Out[1]: array([0, 2], dtype=int64)

Но этот метод очень медленный (имеется в виду, может быть, есть более быстрый путь) и, конечно, не очень pythoni c.

Как я могу сделать это более чистым, более pythoni c, но в основном более эффективным способом?

Обратите внимание, что этот вопрос близок к этот но здесь я хочу применить условие ко всей строке.

Ответы [ 3 ]

1 голос
/ 24 февраля 2020

Вы можете объединить np.argwhere и np.any:

np.argwhere(np.any(a==9,axis=1))[:,0]
1 голос
/ 24 февраля 2020

Используйте np.argwhere, чтобы найти индексы, где a==9, и используйте 0-й столбец этих индексов, чтобы индексировать:

In [171]: a = np.array([[1,2,9,4,2,5],[4,5,1,4,2,4],[2,3,6,7,8,9],[5,6,7,4,3,6]])
     ...: 
     ...: indices = np.argwhere(a==9)
     ...: a[indices[:,0]]
Out[171]: 
array([[1, 2, 9, 4, 2, 5],
       [2, 3, 6, 7, 8, 9]])

... или, если вам просто нужны номера строк, просто сохраните indices[:,0]. Если 9 может появляться более одного раза в строке и вы не хотите, чтобы в списке были повторяющиеся строки, вы можете использовать np.unique для фильтрации результатов (ничего не делает для этого примера):

In [173]: rows = indices[:,0]

In [174]: np.unique(rows)
Out[174]: array([0, 2])
0 голосов
/ 24 февраля 2020

Вы можете попробовать np.nonzero и unique

Проверка на 9

np.unique((a == 9).nonzero()[0])

Out[356]: array([0, 2], dtype=int64)

Проверка на 6

np.unique((a == 6).nonzero()[0])

Out[358]: array([2, 3], dtype=int64)

Проверка на 8

np.unique((a == 8).nonzero()[0])

Out[359]: array([2], dtype=int64)

На несуществующий номер вернуть пустой список

np.unique((a == 88).nonzero()[0])

Out[360]: array([], dtype=int64)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...