значение kth на строку в pytorch? - PullRequest
3 голосов
/ 01 апреля 2020

Учитывая

import torch    
A = torch.rand(9).view((3,3)) # tensor([[0.7455, 0.7736, 0.1772],\n[0.6646, 0.4191, 0.6602],\n[0.0818, 0.8079, 0.6424]])
k = torch.tensor([0,1,0])
A.kthvalue_vectoriezed(k) -> [0.1772,0.6602,0.0818]

Значение Я хотел бы работать с каждым столбцом с различным k.
Не kthvalue, ни topk предлагает такой API. Есть ли векторизованный способ обойти это?
Замечание - k-е значение - это не значение в k-м индексе, а k-й наименьший элемент. Pytorch docs

torch.kthvalue(input, k, dim=None, keepdim=False, out=None) -> (Tensor, LongTensor)
Возвращает именованный кортеж (значения, индексы), где значения - это k-й наименьший элемент каждой строки входного тензора в данном измерении тусклый. И индексы - это местоположение индекса каждого найденного элемента.

1 Ответ

1 голос
/ 01 апреля 2020

Предполагая, что вам не нужны индексы в исходной матрице (если вы это делаете, просто используйте причудливую индексацию и для второго возвращаемого значения), вы можете просто отсортировать значения (по последнему индексу по умолчанию) и вернуть соответствующие значения следующим образом:

def kth_smallest(tensor, indices):
    tensor_sorted, _ = torch.sort(tensor)
    return tensor_sorted[torch.arange(len(indices)), indices]

И этот контрольный пример дает вам желаемые значения:

tensor = torch.tensor(
    [[0.7455, 0.7736, 0.1772], [0.6646, 0.4191, 0.6602], [0.0818, 0.8079, 0.6424]]
)

print(kth_smallest(tensor, [0, 1, 0])) # -> [0.1772,0.6602,0.0818]
...