Если это нормальный тензор, вы можете использовать torch.nonzero
>>> (dist > 0.5).nonzero()
, который будет возвращать индексы всех элементов, которые больше 0.5
Пример:
>>> dist = torch.rand((6,5))
>>> dist
tensor([[0.7549, 0.0962, 0.3198, 0.6868, 0.8117],
[0.0785, 0.7666, 0.2623, 0.5140, 0.2713],
[0.5768, 0.8160, 0.8654, 0.6978, 0.0138],
[0.8147, 0.1394, 0.3204, 0.0104, 0.2872],
[0.1396, 0.5639, 0.7085, 0.7151, 0.8253],
[0.6115, 0.0214, 0.6033, 0.1403, 0.1977]])
>>> (dist > 0.5).nonzero()
tensor([[0, 0],
[0, 3],
[0, 4],
[1, 1],
[1, 3],
[2, 0],
[2, 1],
[2, 2],
[2, 3],
[3, 0],
[4, 1],
[4, 2],
[4, 3],
[4, 4],
[5, 0],
[5, 2]])