Как выполнить расширенную индексацию в PyTorch? - PullRequest
0 голосов
/ 08 марта 2020

Есть ли способ сделать следующее без зацикливания?

S, N, H = 9, 7, 4

a = torch.randn(S, N, H)

# tensor with integer values between 1, S of shape (N,)
lens = torch.randint(1, S + 1, (N,)) 

res = torch.zeros(N, H)

for i in range(N):
    res[i] = a[lens[i] - 1, i, :]

1 Ответ

2 голосов
/ 08 марта 2020

Да, я считаю, что это работает.

import torch

S, N, H = 9, 7, 4

a = torch.randn(S, N, H)

# tensor with integer values between 1, S of shape (N,)
lens = torch.randint(0, S, (N,)) 
i = torch.tensor(range(0,7))
res = torch.zeros(N, H)

res = a[lens, i, :]
print(res)

А почему вы сделали объектив 1 из S + 1, а затем сделали lens[i]-1? Я просто изменил его, чтобы объектив был 0 от S для удобства. Однако если вам нужно, чтобы объектив был 1 от S + 1, вы можете изменить
res = a[lens, i, :]
на
res = a[lens-1, i, :]

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...