Могу ли я извлечь все индексы, соответствующие определенному ключу в тензоре pytorch? - PullRequest
1 голос
/ 21 июня 2020

Скажем, у меня есть тензор pytorch tensor([3,5,7,3,9,3,0]). Я хотел бы извлечь индексы, в которых встречается 3, т.е. tensor([0,3,5]). Есть ли для этого встроенная функция?

Ответы [ 2 ]

0 голосов
/ 21 июня 2020
t = torch.Tensor([1, 2, 3 , 2 , 5])
print ((t == 2).nonzero())

ненулевое значение печатает все ненулевые положения тензора резака https://pytorch.org/docs/master/generated/torch.nonzero.html

0 голосов
/ 21 июня 2020

Для этого есть специальная функция :

   torch.where(my_tensor == the_number)
...