Я думаю, что самое простое (конечно, самое короткое) решение - это einsum
.
import torch
T = torch.randn(100, 20, 400)
M = torch.randn(400, 400)
res = torch.einsum('abc,cd,abd->ab', (T, M, T)).unsqueeze(-1)
В основном это говорит "для всех (a, b, c, d)
в границах, умножить T[a, b, c]
с M[c, d]
и T[a, b, d]
и накапливать его в res[a, b]
".
Поскольку einsum
реализован в терминах базовых строительных блоков, таких как mm
, transpose
и т. д.конечно, это можно было бы развернуть в более «классическое» решение, но сейчас мой мозг подводит меня к этому.