In [331]: A=np.random.rand(100,200,300)
In [332]: B=A
Предлагаемый einsum
, работающий напрямую с
C[i,j,k] = np.dot(A[i,k,:], B[j,k,:]
выражение:
In [333]: np.einsum( 'ikm, jkm-> ijk', A, B).shape
Out[333]: (100, 100, 200)
In [334]: timeit np.einsum( 'ikm, jkm-> ijk', A, B).shape
800 ms ± 25.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
matmul
делает dot
в последних двух измерениях и рассматривает одно из первых в качестве пакета. В вашем случае «k» - это размер пакета, а «m» - тот, который должен подчиняться правилу last A and 2nd to the last of B
. Итак, переписываем ikm,jkm...
, чтобы соответствовать, и транспонируем A
и B
соответственно:
In [335]: np.einsum('kim,kmj->kij', A.transpose(1,0,2), B.transpose(1,2,0)).shape
Out[335]: (200, 100, 100)
In [336]: timeit np.einsum('kim,kmj->kij',A.transpose(1,0,2), B.transpose(1,2,0)).shape
774 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Не большая разница в производительности. Но теперь используйте matmul
:
In [337]: (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape
Out[337]: (100, 100, 200)
In [338]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape
64.4 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
и убедитесь, что значения совпадают (хотя чаще всего, если формы совпадают, значения соответствуют).
In [339]: np.allclose((A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0),np.einsum( 'ikm, jkm->
...: ijk', A, B))
Out[339]: True
Я не буду пытаться измерить использование памяти, но улучшение времени говорит о том, что оно тоже лучше.
В некоторых случаях einsum
оптимизирован для использования matmul
. Здесь это не похоже на случай, хотя мы могли бы поиграть с его параметрами Я немного удивлен, что matmul
делает намного лучше.
===
Я смутно припоминаю еще одну СО о matmul
сокращении, когда два массива - одно и то же, A@A
. Я использовал B=A
в этих тестах.
In [350]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape
60.6 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [352]: B2=np.random.rand(100,200,300)
In [353]: timeit (A.transpose(1,0,2)@B2.transpose(1,2,0)).transpose(1,2,0).shape
97.4 ms ± 164 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Но это имело только скромное значение.
In [356]: np.__version__
Out[356]: '1.16.4'
Мой BLAS и т. Д. Является стандартным Linux, ничего особенного.