Фильтровать массив в Python3 / Numpy и вернуть индексы - PullRequest
6 голосов
/ 27 марта 2012

Есть ли в Python3 / Numpy какая-либо встроенная функция, которая фильтрует массив и возвращает индексы оставшихся элементов? Нечто похожее на numpy.argsort для сортировки. У меня есть фильтр, устанавливающий минимальные и максимальные пороги - все значения ниже / выше минимальных / максимальных значений должны быть отфильтрованы.

Я видел функцию Python filter , но я не вижу способа извлечь индексы, используя ее.

Отредактировано: Много полезной информации в ответах, спасибо!

Как указал @SvenMarnach, маски достаточно:

mask = (min_value < a) & (a < max_value)

Теперь я должен применить эту маску к другим массивам той же формы, что и a, но не уверен, как лучше это сделать ...

Ответы [ 5 ]

5 голосов
/ 27 марта 2012

Команда numpy.where вернет индексы массива после того, как вы применили к ним маску. Например:

import numpy as np
A = np.array([1,2,3,6,2])
np.where(A>2)

дает:

(array([2, 3]),)

Более сложный пример:

A = np.arange(27).reshape(3,3,3)
np.where( (A>10) & (A<15) )

дает:

(array([1, 1, 1, 1]), array([0, 1, 1, 1]), array([2, 0, 1, 2]))

Я соглашусь с @SvenMarnach, обычно вам не нужны индексы.

4 голосов
/ 27 марта 2012

Вы можете получить индексы элементов в одномерном массиве a, которые больше min_value и меньше max_value с помощью

indices = ((min_value < a) & (a < max_value)).nonzero()[0]

Обычно эти индексы обычно не нужны, но вы можете работать с маской более эффективно

mask = (min_value < a) & (a < max_value)

Эта маска является логическим массивом такой же формы, как и a.

Редактировать : Если у вас есть массив b той же формы, что и a, вы можете извлечь элементы b, соответствующие True записям в mask, с помощью

b[mask]
3 голосов
/ 27 марта 2012

Не имеет прямого отношения к вашему вопросу, но filter() является частью набора из трех функций map(), filter() и reduce(), которые разрешают обработку списка в стиле Python.

  • map(mapping_function, input_list) принимает функцию с одним аргументом и списком, применяет функцию к каждому элементу списка по очереди и возвращает выходной список в качестве результата. Это более или менее эквивалентно пониманию списка [mapping_function(item) for item in input_list].

  • filter(filter_function, input_list) возвращает список элементов из input_list, для которых filter_function вернул True. Эквивалент списка составляет [item for item in items if filter_function(item)].

  • reduce(combining_function, input_list) многократно комбинирует соседние пары элементов в списке ввода, пока не останется только одно значение. Например, сумма списка чисел может быть выражена как reduce(operator.add, numbers).

Функциональность map() и filter() обеспечивается списками в Python (именно поэтому функции map и filter используются не очень часто.)

reduce() - это одна из тех вещей, которая не предлагает себя в качестве интуитивного ответа на ... что-либо. Почти всегда проще написать цикл, который объясняет, почему вы не видите его часто.

1 голос
/ 27 марта 2012

Мне очень нравится Sven , и на самом деле, numpy.where делает именно то, что вы хотите, как напомнил мне Крюк.Но в основном потому, что я это уже написал, вот другой подход, просто чтобы проиллюстрировать несколько трюков.my_filter может быть любой функцией, возвращающей логический массив той же формы, что и входные данные:

def my_filter(a):
    return (10 < a) & (a < 40)

a_mask = my_filter(a)
indices = [ind[a_mask] for ind in numpy.indices(a.shape)]

Например:

>>> a = numpy.arange(100).reshape((10, 10))
>>> def my_filter(a):
...     return (min_value < a) & (a < max_value)
... 
>>> a_mask = my_filter(a)
>>> [ind[a_mask] for ind in numpy.indices(a.shape)]
[array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3]), 
 array([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
       4, 5, 6, 7, 8, 9])]
0 голосов
/ 16 апреля 2019

Вам необязательно нужен NumPy, вы можете сделать это с помощью простого понимания списка и enumerate:

a = [1, 2, 3, 6, 2] 
[i for i, v in enumerate(a) if v < 2]

возвращается:

[2, 3]

...