PyTorch тензорная расширенная индексация - PullRequest
3 голосов
/ 08 апреля 2020

Скажем, у меня есть одна матрица и один вектор следующим образом:

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

есть ли способ нарезать его x[y] так что результат:

res = [1, 6, 8]

так что в основном я возьмите первый элемент y и возьмите элемент в x, который соответствует первой строке и столбцу элементов.

Cheers

Ответы [ 2 ]

3 голосов
/ 08 апреля 2020

Вы можете указать соответствующий индекс строки как:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

x[range(x.shape[0]), y]
tensor([1, 6, 8])
2 голосов
/ 08 апреля 2020

Расширенное индексирование в pytorch работает так же, как и NumPy's, то есть массивы индексации передаются вместе по осям. Так что вы могли бы сделать, как в ответе FBruzzesi.

Хотя аналогично np.take_along_axis, в pytorch у вас также есть torch.gather, чтобы принимать значения по заданной c оси:

x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...