Вы можете достичь этого с помощью torch.bmm()
и некоторых torch.squeeze()
/ torch.unsqueeze()
.
Мне лично больше нравятся более общие torch.einsum()
(которые я считаю более читабельными):
import torch
import numpy as np
A = torch.from_numpy(np.array([[[1, 10, 100], [2, 20, 200], [3, 30, 300]],
[[4, 40, 400], [5, 50, 500], [6, 60, 600]]]))
B = torch.from_numpy(np.array([[ 1, 2, 3],
[-1, -2, -3]]))
AB = torch.einsum("nbh,nb->nh", (A, B))
print(AB)
# tensor([[ 14, 140, 1400],
# [ -32, -320, -3200]])