Возможно, это можно написать более элегантно с помощью np.einsum()
:
import numpy as np
n = 10
a = 2 * np.ones((n, n, 3))
b = 3 * np.ones((n, n, 3))
s = 0
for i in range(n):
for j in range(n):
s += a * b[i, j]
print(s.shape)
# (10, 10, 3)
ss = a * np.einsum('ijk->k', b)
print(ss.shape)
# (10, 10, 3)
print(np.all(s == ss))
# True
или даже просто np.sum()
:
sss = a * np.sum(b, axis=(0, 1))
print(sss.shape)
# (10, 10, 3)
print(np.all(s == sss))
# True
, но np.einsum()
кажется быстрее:
n = 100
a = 2 * np.ones((n, n, 3))
b = 3 * np.ones((n, n, 3))
%timeit f_with_loops(a, b)
# 1 loop, best of 3: 787 ms per loop
%timeit a * np.einsum('ijk->k', b)
# 10000 loops, best of 3: 121 µs per loop
%timeit a * np.sum(b, axis=(0, 1))
# 1000 loops, best of 3: 254 µs per loop