Есть ли способ условно индексировать трехмерный массив? - PullRequest
2 голосов
/ 15 апреля 2019

Имея массив A с формой (2,6, 60), можно ли индексировать его на основе двоичного массива B формы (6,)?

6 и 60 совершенно произвольны, они просто2D-данные, к которым я хочу получить доступ.

Основная задача, которую я пытаюсь сделать, - это рассчитать два варианта 2D-данных (в данном случае (6,60)) и затем эффективно выбрать те, которые имеют наименьшую общую сумму- отсюда бинарный массив (6,).

Пример: для B = [1,0,1,0,1,0] то, что я хочу получить, равно суммированию

A[1,0,:]
A[0,1,:]
A[1,2,:]
A[0,3,:]
A[1,4,:]
A[0,5,:]

, но я хотел бы сделать это путем прямой индексации, а не цикла for.

Я попытался A[B], A[:,B,:], A[B,:,:] A[:,:,B], но ни один из них не дал желаемой (6,60) матрицы.

import numpy as np
A = np.array([[4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1]])
A = np.atleast_3d(A)
A = np.tile(A, (1,1,60)
B = np.array([1, 0, 1, 0, 1, 0])
A[B]

Ожидаемые результаты - массив (6,60), содержащий элементы из A, как описано выше,получено либо (2,6,60), либо (6,6,60).

Заранее спасибо, Линус

Ответы [ 2 ]

1 голос
/ 15 апреля 2019

Вы можете сгенерировать диапазон индексов, по которым вы хотите перебрать, в вашем случае от 0 до 5:

count = A.shape[1]

indices = np.arange(count)  # np.arange(6) for your particular case

>>> print(indices)
array([0, 1, 2, 3, 4, 5])

И затем вы можете использовать это для расширенного индексирования:

result_array = A[B[indices], indices, :]

Если вы всегда используете полный диапазон от 0 до длины - 1 (т.е. от 0 до 5 в вашем случае) второй оси A в порядке возрастания, вы можете упростить это до:

result_array = A[B, indices, :]
# or the ugly result_array = A[B, np.arange(A.shape[1]), :]

Или даже это, если всегда 6:

result_array = A[B, np.arange(6), :]
0 голосов
/ 15 апреля 2019

Альтернативное решение с использованием np.take_along_axis (от версии 1.15 - документы )

import numpy as np
x = np.arange(2*6*6).reshape((2,6,6))
m = np.zeros(6, int)
m[0] = 1
#example: [1, 0, 0, 0, 0, 0]

np.take_along_axis(x, m[None, :, None], 0)   #add dimensions to mask to match array dimensions

>>array([[[36, 37, 38, 39, 40, 41],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...