Вы правы, что в случае (N, 3, 3) * (N, 3, k)
вы не можете использовать np.dot
напрямую, потому что результат будет (N, 3, N, k)
. Вам фактически придется извлечь N
диагональные элементы из осей 0 и 2, но это много ненужных вычислений. Случай (N, 3, 3) * (1, 3, k)
может быть решен с помощью np.dot
, если после применения применить squeeze
для удаления ненужной третьей оси: result = a.dot(b).squeeze()
.
Хорошая новость в том, что вам не нужен np.dot
, чтобы получить точечное произведение. Вот три варианта:
Проще говоря, используйте оператор @
, эквивалентный np.matmul
, который требует, чтобы ведущие измерения передавались вместе:
a @ b
np.matmul(a, b)
Если ваши матрицы находятся не в двух последних измерениях, вы можете транспонировать их в. Это может быть неэффективно, потому что это, вероятно, скопирует данные. В этих случаях читайте дальше.
Другое популярное решение - использовать np.einsum
для явного указания совпадающих осей и осей для суммирования:
np.einsum('ijk,ikl->ijl', a, b)
Вещание позаботится как о i = N
, так и i = 1
случаях.
И, конечно, вы всегда можете взять сумму-результат вручную, используя *
/ np.multiply
и np.sum
:
(a[..., None] * b[:, None, ...]).sum(axis=2)