как получить индекс подмассива в pytorch? - PullRequest
2 голосов
/ 27 мая 2020

a и b - тензор факела Нет повторяющихся элементов a shape is [n, 2] like:

[[1,2]
[2,3]
[4,6]
...]

b is [m, 2] как:

[[1,2]
[4,6]
....
]

как чтобы получить индекс b в a, например:

a = [[1,2]
[2,4]
[6,7]
]
b = [[1,2]
[6,7]]

индекс должен быть (0,3), мы можем использовать gpu ,

Ответы [ 2 ]

3 голосов
/ 27 мая 2020

Я могу придумать следующий трюк, который может сработать для вас.

Поскольку у нас есть два тензора с разным количеством строк (n и m), сначала мы преобразуем их в одну и ту же форму (m x n x 2 ), а затем вычтите. Если две строки совпадают, то после вычитания вся строка будет равна нулю. Затем нам нужно определить индексы этих строк.

n = a.shape[0] # 3
m = b.shape[0] # 2
_a = a.unsqueeze(0).repeat(m, 1, 1) # m x n x 2
_b = b.unsqueeze(1).repeat(1, n, 1) # m x n x 2

match = (_a - _b).sum(-1) # m x n
indices = (match == 0).nonzero()
if indices.nelement() > 0: # empty tensor check
    row_indices = indices[:, 1]
else:
    row_indices = []

print(row_indices)

Пример ввода / вывода

Пример 1

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 7]])
tensor([0, 2])

Пример 2

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 3], [6, 7]])
tensor([2])

Пример 3

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 5], [8, 9]])
tensor([0])

Пример 4

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 3], [6, 5], [8, 9]])
[]
1 голос
/ 28 мая 2020

Здесь решение @jpp, numpy - почти ваш ответ после этого

Вам просто нужно получить индексы, используя nonzero и сглаживать тензор используя flatten, чтобы получить ожидаемую форму.

a = torch.tensor([[1, 2], [2, 4], [6, 7]])
b = torch.tensor([[1, 2], [6, 7]])
(a[:, None] == b).all(-1).any(-1).nonzero().flatten()
tensor([0, 2])
...