Я могу придумать следующий трюк, который может сработать для вас.
Поскольку у нас есть два тензора с разным количеством строк (n и m), сначала мы преобразуем их в одну и ту же форму (m x n x 2
), а затем вычтите. Если две строки совпадают, то после вычитания вся строка будет равна нулю. Затем нам нужно определить индексы этих строк.
n = a.shape[0] # 3
m = b.shape[0] # 2
_a = a.unsqueeze(0).repeat(m, 1, 1) # m x n x 2
_b = b.unsqueeze(1).repeat(1, n, 1) # m x n x 2
match = (_a - _b).sum(-1) # m x n
indices = (match == 0).nonzero()
if indices.nelement() > 0: # empty tensor check
row_indices = indices[:, 1]
else:
row_indices = []
print(row_indices)
Пример ввода / вывода
Пример 1
a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 7]])
tensor([0, 2])
Пример 2
a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 3], [6, 7]])
tensor([2])
Пример 3
a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 5], [8, 9]])
tensor([0])
Пример 4
a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 3], [6, 5], [8, 9]])
[]