Получить значение индекса переменной в определенном измерении - PullRequest
0 голосов
/ 09 марта 2020

Скажите, если у меня есть тензор, который

value = torch.tensor([

    [[0, 0, 0], [1, 1, 1]],
    [[2, 2, 2], [3, 3, 3]],
])

по существу с формой (2,2,3).

Теперь скажите, если у меня есть index = [1, 0], что означает, что я хочу взять:

# row 1 of [[0, 0, 0], [1, 1, 1]], giving me: [1, 1, 1]
# row 0 of [[2, 2, 2], [3, 3, 3]], giving me: [2, 2, 2]

Итак, итоговый результат:

output = torch.tensor([[1, 1, 1], [2, 2, 2]])

есть ли векторизованный способ достижения этого?

1 Ответ

0 голосов
/ 09 марта 2020

Вы можете использовать расширенную индексацию.
Я не могу найти хороший документ Pytorch по этому поводу, но я считаю, что он работает так же, как numpy, так что вот документ numpy об индексации .

import torch

value = torch.tensor([

    [[0, 0, 0], [1, 1, 1]],
    [[2, 2, 2], [3, 3, 3]],
])

index = [1, 0]
i = range(0,2)

result = value[i, index]
# same as result = value[i, index, :] 

print(result)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...