У меня есть двумерный тензор целых чисел, и я хотел бы найти индексы строк, столбцы которых содержат любое из указанных значений.
Например, с учетом этого data
тензор
data = torch.randint(10, (10,5))
tensor([[4, 7, 9, 8, 5],
[7, 4, 4, 3, 3],
[4, 9, 7, 7, 0],
[8, 1, 4, 6, 0],
[5, 9, 9, 5, 8],
[9, 3, 7, 6, 5],
[0, 2, 3, 5, 2],
[4, 4, 1, 5, 1],
[9, 8, 3, 7, 1],
[3, 2, 0, 4, 7]])
и эти списки (или, возможно, тензоры) значений
col1_values = [4, 5]
col2_values = [9, 4]
Я хотел бы получить индексы как таковые:
tensor([2, 4, 7])
Я знаю, что могу комбинировать логические маскиодин за другим
filter = ((data[:,0] == 4) + (data[:,0] == 5)) * ((data[:,1] == 9) + (data[:,1] == 4))
indices = filter.nonzero().squeeze()
, который я мог бы автоматизировать с помощью циклов. Но есть ли способ сделать это более эффективно, используя функции pytorch?