PyTorch: найти индексы тензорных рядов, которые удовлетворяют заданным условиям - PullRequest
0 голосов
/ 19 октября 2019

У меня есть двумерный тензор целых чисел, и я хотел бы найти индексы строк, столбцы которых содержат любое из указанных значений.

Например, с учетом этого data тензор

data = torch.randint(10, (10,5))
tensor([[4, 7, 9, 8, 5],
        [7, 4, 4, 3, 3],
        [4, 9, 7, 7, 0],
        [8, 1, 4, 6, 0],
        [5, 9, 9, 5, 8],
        [9, 3, 7, 6, 5],
        [0, 2, 3, 5, 2],
        [4, 4, 1, 5, 1],
        [9, 8, 3, 7, 1],
        [3, 2, 0, 4, 7]])

и эти списки (или, возможно, тензоры) значений

col1_values = [4, 5]
col2_values = [9, 4]

Я хотел бы получить индексы как таковые:

tensor([2, 4, 7])

Я знаю, что могу комбинировать логические маскиодин за другим

filter = ((data[:,0] == 4) + (data[:,0] == 5)) * ((data[:,1] == 9) + (data[:,1] == 4))
indices = filter.nonzero().squeeze()

, который я мог бы автоматизировать с помощью циклов. Но есть ли способ сделать это более эффективно, используя функции pytorch?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...