Факел найти индексы совпадающих строк в 2-х 2D тензорах - PullRequest
1 голос
/ 12 января 2020

У меня есть два 2D-тензора разной длины, оба являются разными подмножествами одного и того же исходного 2-мерного тензора, и я хотел бы найти все подходящие "строки"
например

A = [[1,2,3],[4,5,6],[7,8,9],[3,3,3]
B = [[1,2,3],[7,8,9],[4,4,4]]
torch.2dintersect(A,B) -> [0,2] (the indecies of A that B also have)

Я вижу только numpy решения, которые используют dtype в качестве диктов и не работают для pytorch.


Вот как я это делаю в numpy

arr1 = edge_index_dense.numpy().view(np.int32)
arr2 = edge_index2_dense.numpy().view(np.int32)
arr1_view = arr1.view([('', arr1.dtype)] * arr1.shape[1])
arr2_view = arr2.view([('', arr2.dtype)] * arr2.shape[1])
intersected = np.intersect1d(arr1_view, arr2_view, return_indices=True)

Ответы [ 2 ]

1 голос
/ 02 марта 2020

Если A и B являются 2D-тензорами, следующий код находит такие индексы, что A[indices] == B. Если несколько индексов удовлетворяют этому условию, возвращается первый найденный индекс. Если не все элементы B присутствуют в A, соответствующий индекс игнорируется.

values, indices = torch.topk(((A.t() == B.unsqueeze(-1)).all(dim=1)).int(), 1, 1)
indices = indices[values!=0]
# indices = tensor([0, 2])
1 голос
/ 12 января 2020

Этот ответ был опубликован до того, как ОП обновил вопрос другими ограничениями, которые немного изменили проблему.

TL; DR Вы можете сделать что-то вроде этого:

torch.where((A == B).all(dim=1))[0]

Сначала предположим, что у вас есть:

import torch
A = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
B = torch.Tensor([[1,2,3],[4,4,4],[7,8,9]])

Мы можем проверить, что A == B возвращает:

>>> A == B
tensor([[ True,  True,  True],
        [ True, False, False],
        [ True,  True,  True]])

Итак, что мы хотим: строки в котором они все True. Для этого мы можем использовать операцию .all() и указать интересующее измерение, в нашем случае 1:

>>> (A == B).all(dim=1)
tensor([ True, False,  True])

Что вы действительно хотите знать, так это где True s. Для этого мы можем получить первый вывод функции torch.where():

>>> torch.where((A == B).all(dim=1))[0]
tensor([0, 2])
...