как нарезать 2d массив на основе индексов, заданных в другом 2d массиве - PullRequest
2 голосов
/ 24 апреля 2020

У меня есть массив MxN с именем A, в котором хранятся нужные мне данные. У меня есть еще M x N2 массив B, в котором хранятся индексы массива, и N2<N. В каждой строке B хранятся индексы элементов, которые я хочу получить из A для этой строки. Например, следующий код работает для меня:

A_reduced = np.zeros((M,N2))
for i in range(M):
    A_reduced[i,:] = A[i,B[i,:]]

Существуют ли какие-либо «векторизованные» способы извлечения нужных элементов из A на основе B вместо циклического прохождения каждой строки?

Ответы [ 2 ]

1 голос
/ 24 апреля 2020
In [203]: A = np.arange(12).reshape(3,4)                                                               
In [204]: B = np.array([[0,2],[1,3],[3,0]])   

Ваша итерация строки:

In [207]: A_reduced = np.zeros((3,2),int)                                                              
In [208]: for i in range(3): 
     ...:     A_reduced[i,:] = A[i, B[i,:]] 
     ...:                                                                                              
In [209]: A_reduced                                                                                    
Out[209]: 
array([[ 0,  2],
       [ 5,  7],
       [11,  8]])

'Векторизованная' версия:

In [210]: A[np.arange(3)[:,None], B]                                                                   
Out[210]: 
array([[ 0,  2],
       [ 5,  7],
       [11,  8]])

и усовершенствованная с помощью новой функции sh:

In [212]: np.take_along_axis(A,B,axis=1)                                                               
Out[212]: 
array([[ 0,  2],
       [ 5,  7],
       [11,  8]])
1 голос
/ 24 апреля 2020

Вы можете использовать индексирование массива и использовать изменение формы:

# set up M=N=4, N2=2
a = np.arange(16).reshape(4,4)
b = np.array([[1,2],[0,1],[2,3],[1,3]])

row_idx = np.repeat(np.arange(b.shape[0]),b.shape[1])
col_idx = b.ravel()

# output:
a[row_idx, col_idx].reshape(b.shape)

Вывод:

array([[ 1,  2],
       [ 4,  5],
       [10, 11],
       [13, 15]])

Обновление : другое похожее решение

row_idx = np.repeat(np.arange(b.shape[0]),b.shape[1]).reshape(b.shape)

# output
a[row_idx,b]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...