Скажем, у меня есть тензор pytorch tensor([3,5,7,3,9,3,0]). Я хотел бы извлечь индексы, в которых встречается 3, т.е. tensor([0,3,5]). Есть ли для этого встроенная функция?
tensor([3,5,7,3,9,3,0])
3
tensor([0,3,5])
t = torch.Tensor([1, 2, 3 , 2 , 5]) print ((t == 2).nonzero())
ненулевое значение печатает все ненулевые положения тензора резака https://pytorch.org/docs/master/generated/torch.nonzero.html
Для этого есть специальная функция :
torch.where(my_tensor == the_number)