Этот ответ был опубликован до того, как ОП обновил вопрос другими ограничениями, которые немного изменили проблему.
TL; DR Вы можете сделать что-то вроде этого:
torch.where((A == B).all(dim=1))[0]
Сначала предположим, что у вас есть:
import torch
A = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
B = torch.Tensor([[1,2,3],[4,4,4],[7,8,9]])
Мы можем проверить, что A == B
возвращает:
>>> A == B
tensor([[ True, True, True],
[ True, False, False],
[ True, True, True]])
Итак, что мы хотим: строки в котором они все True
. Для этого мы можем использовать операцию .all()
и указать интересующее измерение, в нашем случае 1
:
>>> (A == B).all(dim=1)
tensor([ True, False, True])
Что вы действительно хотите знать, так это где True
s. Для этого мы можем получить первый вывод функции torch.where()
:
>>> torch.where((A == B).all(dim=1))[0]
tensor([0, 2])