Как я могу просто отфильтровать массив 3d numpy по значениям его 1-го столбца? - PullRequest
1 голос
/ 09 мая 2020

Предположим, у меня есть массив 3D numpy вроде этого:

 data = np.array([[[1,2,3,4],[1,2.5,3,5]],
                 [[116,230,450,430],[80,100,300,320]],
                 [[60,100,120,80],[50,80,100,90]]])

Как я могу просто извлечь из него массив 3D numpy той же формы с условием на оси 0, например, выбрав те "строки", для которых ось 0 <3? Наивный способ: </p>

data[data[0]<3]

Но это не удается:

IndexError: логический индекс не соответствует индексированному массиву по измерению 0; размер равен 3, но соответствующий логический размер равен 2

1 Ответ

1 голос
/ 09 мая 2020

См. Мой комментарий выше, но исходя из ваших данных я предполагаю, что вам нужны строки с любыми значениями меньше 3. Если это так, вы можете сделать:

data[(data<3).any(axis=2)]
>>> array([[1. , 2. , 3. , 4. ],
           [1. , 2.5, 3. , 5. ]])

EDIT1:

Решение может быть достигнуто с помощью транспонирования для согласования размеров оси:

data.T[(data[0]<3).any(axis=0).T].T
>>> array([[[  1. ,   2. ],
            [  1. ,   2.5]],

           [[116. , 230. ],
            [ 80. , 100. ]],

           [[ 60. , 100. ],
            [ 50. ,  80. ]]])

EDIT2:

Другой метод, не требующий транспонирования. Чтобы применить маску (data[0]<3).any(axis=0) к исходному массиву данных, формы осей должны совпадать. Форма маски - (4,) и data.shape = (3, 2, 4), поэтому нам нужно применить маску к последней оси как:

data[..., (data[0]<3).any(axis=0)]
>>> array([[[  1. ,   2. ],
            [  1. ,   2.5]],

           [[116. , 230. ],
            [ 80. , 100. ]],

           [[ 60. , 100. ],
            [ 50. ,  80. ]]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...