Это можно сделать с помощью трансляции:
X@rotation_z_matrix.transpose(0,2,1)[:, None, ...]
Это дает (на фиктивном наборе данных) тот же ответ, что и у @ Divakar
batch_size = 10
seq_length = 8
n_coordinates = 12
X = np.random.randint(0,10,(batch_size, seq_length, n_coordinates, 3))
rotation_z_matrix = np.random.randint(0,10,(batch_size,3,3))
(X@rotation_z_matrix.transpose(0,2,1)[:, None, ...] == np.einsum('ijk,ilmk->ilmj',rotation_z_matrix,X)).all()
# True
Но по крайней мере для этого примера этозначительно быстрее.
timeit(lambda: np.einsum('ijk,ilmk->ilmj',rotation_z_matrix,X, optimize=True), number=1000)
# 0.1285447319969535
timeit(lambda: np.einsum('ijk,ilmk->ilmj',rotation_z_matrix,X, optimize=False), number=1000)
# 0.07962286799738649
timeit(lambda: X@rotation_z_matrix.transpose(0,2,1)[:, None, ...], number=1000)
# 0.019039910010178573
Обязательно обратите внимание, что установка флага optimize
на самом деле замедляет einsum
.(Это случается довольно часто со мной.)
Обновление: тот же пример, но с данными, преобразованными в float dtype
timeit(lambda: np.einsum('ijk,ilmk->ilmj',rotation_z_matrix,X, optimize=True), number=1000)
# 0.12346570500812959
timeit(lambda: np.einsum('ijk,ilmk->ilmj',rotation_z_matrix,X, optimize=False), number=1000)
# 0.07575376800377853
timeit(lambda: X@rotation_z_matrix.transpose(0,2,1)[:, None, ...], number=1000)
# 0.027829282989841886