Вы можете использовать torch.bmm
(https://pytorch.org/docs/stable/torch.html#torch .bmm ). Просто выполните torch.bmm(I,W)
Чтобы проверить результаты:
import torch
N, C, X, Y= 100, 10, 9, 8
i = torch.rand(N,C,X)
w = torch.rand(N,X,Y)
o = torch.bmm(i,w)
# desired result code
I = i.view(N, C, X, 1)
W = w.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)
print(torch.allclose(O,o)) # should output True if outputs are same.
РЕДАКТИРОВАТЬ: В идеале, я бы предположил, что использование умножения внутренней матрицы Pytorch является эффективным. Однако вы также можете измерить использование памяти с помощью tracemalloc
(по крайней мере, на процессоре). См. https://discuss.pytorch.org/t/measuring-peak-memory-usage-tracemalloc-for-pytorch/34067 для графического процессора.
import torch
import tracemalloc
tracemalloc.start()
N, C, X, Y= 100, 10, 9, 8
i = torch.rand(N,C,X)
w = torch.rand(N,X,Y)
o = torch.bmm(i,w)
# output is a tuple indicating current memory and peak memory
print(tracemalloc.get_traced_memory())
Вы можете сделать то же самое с другим кодом и увидеть, что реализация bmm
действительно эффективна.
import torch
import tracemalloc
tracemalloc.start()
N, C, X, Y= 100, 10, 9, 8
i = torch.rand(N,C,X)
w = torch.rand(N,X,Y)
I = i.view(N, C, X, 1)
W = w.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)
# output is a tuple indicating current memory and peak memory
print(tracemalloc.get_traced_memory())