Как отсортировать тензоры pytorch по заданному значению ключа c? - PullRequest
1 голос
/ 13 марта 2020

Я новичок в Pytorch. Учитывая тензорный набор, мне нужно отсортировать эти тензоры по значению ключа. Например,

A = 
[[0.9133, 0.5071, 0.6222, 3.],
 [0.5951, 0.9315, 0.6548, 1.],
 [0.7704, 0.0720, 0.0330, 2.]]

Мой ожидаемый результат после сортировки:

A' = 
[[0.5951, 0.9315, 0.6548, 1.],
 [0.7704, 0.0720, 0.0330, 2.],
 [0.9133, 0.5071, 0.6222, 3.]]

Я пытался использовать функцию sorted в python, но это было время. потребляя в моем тренировочном процессе. Как добиться этого более эффективно? Спасибо!

1 Ответ

1 голос
/ 13 марта 2020
%%timeit -r 10 -n 10
A[A[:,-1].argsort()]

38.6 µs ± 23 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)

%%timeit -r 10 -n 10
sorted(A, key = lambda x: x[-1])

69.6 µs ± 34.8 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)

Оба выводят

tensor([[0.5951, 0.9315, 0.6548, 1.0000],
        [0.7704, 0.0720, 0.0330, 2.0000],
        [0.9133, 0.5071, 0.6222, 3.0000]])

Тогда есть

%%timeit -r 10 -n 10
a, b = torch.sort(A, dim=-2)

The slowest run took 8.45 times longer than the fastest. This could mean that an intermediate result is being cached.
14.3 µs ± 18.1 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)

с a в качестве отсортированного тензора и b в качестве индексов

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...