Индексирование многомерного массива с набором индексов из индексационного массива - NumPy / Python - PullRequest
2 голосов
/ 08 июля 2019

У меня есть трехмерный массив a с размерами (6,m,n).У меня также есть 6-D логический массив NumPy b с размерами (20,20,20,20,20,20), который эффективно работает как маска.

Я хотел бы использовать 6 значений в каждом местоположении (m,n) в первом массиве, чтобы получить соответствующее значение во втором массиве.По сути, я сожму массив 3D int в двумерный логический массив.Я думал, что решение будет использовать np.where, но я не думаю, что оно может иметь дело с использованием значений в качестве индексов.

Наивная реализация для этого будет выглядеть примерно так:

for i in range(m):
    for j in range(n):
         new_arr[i,j]=b[tuple(a[:,i,j])]

Есть ли способ реализовать это без использования цикла?

1 Ответ

1 голос
/ 08 июля 2019

Подход № 1

Измените a на 2D, сохранив длину первой оси одинаковой. Преобразуйте каждый таким образом 2D-плоский блок в кортеж и затем индексируйте в b. Это преобразование кортежей приводит к упаковке каждого элемента вдоль первой оси в качестве индексатора , чтобы выбрать элемент каждый из b. Наконец, необходимо изменить форму, чтобы получить 2D вывод. Следовательно, реализация будет выглядеть примерно так -

b[tuple(a.reshape(6,-1))].reshape(m,n)

Или пропустите все эти изменения и просто сделайте -

b[tuple(a)]

Это делает то же самое создание индексатора и решает проблему.

Подход № 2

В качестве альтернативы, мы также можем вычислить сглаженные индексы, а затем индексировать их до сглаженных b и извлечь из них соответствующие логические значения -

b.ravel()[np.ravel_multi_index(a,b.shape)]

Синхронизация большого набора данных -

In [89]: np.random.seed(0)
    ...: m,n = 500,500
    ...: b = np.random.rand(20,20,20,20,20,20)>0.5
    ...: a = np.random.randint(0,20,(6,m,n))

In [90]: %timeit b[tuple(a)]
14.6 ms ± 184 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [91]: %timeit b.ravel()[np.ravel_multi_index(a,b.shape)]
7.35 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...