NumPy: как отфильтровать первые оси многомерного массива по некоторому условию на элементы - PullRequest
0 голосов
/ 21 апреля 2020

Рассмотрим следующий ndarray lm -

In [135]: lm
Out[135]:
array([[[15,  7],
        [ 2,  3],
        [ 0,  4]],

       [[ 8, 12],
        [ 6,  5],
        [17, 10]],

       [[16, 13],
        [30,  1],
        [14,  9]]])

In [136]: lm.shape
Out[136]: (3, 3, 2)

Я хочу отфильтровать элементы первых осей (lm[0], lm[1], ...), где хотя бы один из элементов больше 20. Поскольку lm[2, 1, 0] является единственным элементом, удовлетворяющим этому условию, я ожидал бы следующий результат -

array([[[15,  7],
        [ 2,  3],
        [ 0,  4]],

       [[ 8, 12],
        [ 6,  5],
        [17, 10]]]

т.е. lm[2] имеет хотя бы один элемент> 20, поэтому он отфильтрован из набора результатов. Как мне этого добиться?

1 Ответ

2 голосов
/ 21 апреля 2020

Два способа сделать это с np.all и np.any с axis arg -

In [14]: lm[(lm<=20).all(axis=(1,2))]
Out[14]: 
array([[[15,  7],
        [ 2,  3],
        [ 0,  4]],

       [[ 8, 12],
        [ 6,  5],
        [17, 10]]])

In [15]: lm[~(lm>20).any(axis=(1,2))]
Out[15]: 
array([[[15,  7],
        [ 2,  3],
        [ 0,  4]],

       [[ 8, 12],
        [ 6,  5],
        [17, 10]]])

Чтобы сделать его обобщенным c чтобы ndarrays работал по двум последним осям, используйте axis=(-2,-1).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...