Ниже приведены две разные альтернативы. Первый использует ndarray.sum
и NumPy индексирование целочисленных массивов . Второй вариант использует np.einsum
.
def using_sum(T):
total = T.sum(axis=1).sum(axis=0)
m = np.arange(T.shape[0])
trace = T[m, m].sum(axis=0)
return total - trace
def using_einsum(T):
return np.einsum('mnijk->ijk', T) - np.einsum('nnijk->ijk', T)
Первый аргумент np.einsum
указывает индексы суммирования.
'mnijk->ijk'
означает, что T
имеет нижние индексы mnijk
, и после суммирования остаются только нижние индексы ijk
. Поэтому суммирование выполняется по подпискам m
и n
. Это делает
np.einsum('mnijk->ijk', T)[i,j,k]
равно np.sum(T[:,:,i,j,k])
, но вычисляет весь массив за одно векторизованное вычисление.
Аналогично, 'nnijk->ijk'
сообщает np.einsum
, что у T
есть подписки nnijk
, и снова только подписчики ijk
выживают при суммировании. Поэтому суммирование закончено n
. Поскольку n
повторяется, суммирование по n
вычисляет след.
Мне нравится np.einsum
, потому что он сообщает о намерениях вычислений
сжато. Но также хорошо знать, как работает using_sum
, так как
он использует фундаментальные операции NumPy. Это хороший пример того, как вложенные циклы
можно избежать с помощью методов NumPy, которые работают с целыми массивами.
Вот perfplot , сравнивающий производительность orig
против using_sum
и using_einsum
в зависимости от n
, где T
имеет форму (10, 10, n, n, n)
:
import perfplot
import numpy as np
def orig(T):
_, _, nx, ny, nz = T.shape
r = np.zeros((nx, ny, nz))
for i in range(nx):
for j in range(ny):
for k in range(nz):
r[i,j,k] = np.sum(T[:,:,i,j,k])-np.trace(T[:,:,i,j,k])
return r
def using_einsum(T):
r = np.einsum('mnijk->ijk', T) - np.einsum('nnijk->ijk', T)
return r
def using_sum(T):
total = T.sum(axis=1).sum(axis=0)
m = np.arange(T.shape[0])
trace = T[m,m].sum(axis=0)
return total - trace
def make_T(n):
return np.random.random((10,10,n,n,n))
perfplot.show(
setup=make_T,
kernels=[orig, using_sum, using_einsum],
n_range=range(2, 80, 3),
xlabel='n')
perfplot.show
также проверяет, что значения, возвращаемые orig
, using_sum
и using_einsum
равны.