Не знаю об эффективном методе, но вы можете попробовать это:
Метод 1:
Использование torch.cat()
import torch
def multiply(a, b):
x1 = a[0:2, 0:2]*b[0,0]
x2 = a[0:2, 2:]*b[0,1]
x3 = a[2:, 0:2]*b[1,0]
x4 = a[2:, 2:]*b[1,1]
return torch.cat((torch.cat((x1, x2), 1), torch.cat((x3, x4), 1)), 0)
a = torch.tensor([[1, 1, 2, 2],[1, 1, 2, 2],[3, 3, 4, 4,],[3, 3, 4, 4]])
b = torch.tensor([[1, 0],[0, 0]])
print(multiply(a, b))
вывод:
tensor([[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]])
Способ 2:
Использование torch.nn.functional.pad()
import torch.nn.functional as F
import torch
def multiply(a, b):
b = F.pad(input=b, pad=(1, 1, 1, 1), mode='constant', value=0)
b[0,0] = 1
b[0,1] = 1
b[1,0] = 1
return a*b
a = torch.tensor([[1, 1, 2, 2],[1, 1, 2, 2],[3, 3, 4, 4,],[3, 3, 4, 4]])
b = torch.tensor([[1, 0],[0, 0]])
print(multiply(a, b))
вывод:
tensor([[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]])