Numba также может быть опцией
Я был немного удивлен не очень хорошими временами Numexpr, поэтому я попробовал версию Numba. Для больших массивов это можно оптимизировать дальше. (Могут применяться те же принципы, что и для dgemm)
import numpy as np
import numba as nb
import numexpr as ne
@nb.njit(fastmath=True,parallel=True)
def min_pairwise_prod(A,B):
assert A.shape[1]==B.shape[1]
res=np.empty((A.shape[0],B.shape[0]))
for i in nb.prange(A.shape[0]):
for j in range(B.shape[0]):
min_prod=A[i,0]*B[j,0]
for k in range(B.shape[1]):
prod=A[i,k]*B[j,k]
if prod<min_prod:
min_prod=prod
res[i,j]=min_prod
return res
Задержка
A=np.random.rand(300,300)
B=np.random.rand(300,300)
%timeit res_1=min_pairwise_prod(A,B) #parallel=True
5.56 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res_1=min_pairwise_prod(A,B) #parallel=False
26 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res_2 = ne.evaluate('min(A3D*B,2)',{'A3D':A[:,None]})
87.7 ms ± 265 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res_3=np.min(A[:,None]*B,axis=2)
110 ms ± 214 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
A=np.random.rand(1000,300)
B=np.random.rand(1000,300)
%timeit res_1=min_pairwise_prod(A,B) #parallel=True
50.6 ms ± 401 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res_1=min_pairwise_prod(A,B) #parallel=False
296 ms ± 5.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res_2 = ne.evaluate('min(A3D*B,2)',{'A3D':A[:,None]})
992 ms ± 7.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res_3=np.min(A[:,None]*B,axis=2)
1.27 s ± 15.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)