Numpy сумма за повторяющиеся записи в индексном массиве - PullRequest
0 голосов
/ 02 марта 2020

Учитывая numpy ndarray A и целочисленный массив I, той же формы, с наибольшим значением imax и массивом B = np.zeros(imax), который мы можем сделать B[I] = A. Однако, если I имеет повторные записи, выполняется последнее назначение. Мне нужно сделать это при суммировании по повторяющимся записям, например,

For i in range(A.size):
    B[I.ravel()[i]] += A.ravel()[i]

Есть ли хороший способ сделать это в numpy?

Например, я хочу, чтобы это поведение (но ни =, ни += не работало так)

A = np.array((1,2,5,9))
I = np.array((0,1,2,0),dtype=int)
B = np.zeros(3)
B[I] += A
print(B)
>>> array([10,2,5])

Здесь мы видим 1+9=10 в первой записи.

1 Ответ

2 голосов
/ 03 марта 2020
In [1]: A = np.array((1,2,5,9)) 
   ...: I = np.array((0,1,2,0),dtype=int) 
   ...: B = np.zeros(3) 
   ...: B[I] += A                                                                                                          
In [2]: B                                                                                                                  
Out[2]: array([9., 2., 5.])

Это буферизованное решение, отличающееся от итеративного:

In [3]: B = np.zeros(3)                                                                                                    
In [4]: for i,a in zip(I,A): 
   ...:     B[i] += a 
   ...:                                                                                                                    
In [5]: B                                                                                                                  
Out[5]: array([10.,  2.,  5.])

Небуферизованное решение с использованием ufunc.at:

In [6]: B = np.zeros(3)                                                                                                    
In [7]: np.add.at(B, I, A)                                                                                                 
In [8]: B                                                                                                                  
Out[8]: array([10.,  2.,  5.])
...