Предполагая, что вам не нужны индексы в исходной матрице (если вы это делаете, просто используйте причудливую индексацию и для второго возвращаемого значения), вы можете просто отсортировать значения (по последнему индексу по умолчанию) и вернуть соответствующие значения следующим образом:
def kth_smallest(tensor, indices):
tensor_sorted, _ = torch.sort(tensor)
return tensor_sorted[torch.arange(len(indices)), indices]
И этот контрольный пример дает вам желаемые значения:
tensor = torch.tensor(
[[0.7455, 0.7736, 0.1772], [0.6646, 0.4191, 0.6602], [0.0818, 0.8079, 0.6424]]
)
print(kth_smallest(tensor, [0, 1, 0])) # -> [0.1772,0.6602,0.0818]