Да, я считаю, что это работает.
import torch
S, N, H = 9, 7, 4
a = torch.randn(S, N, H)
# tensor with integer values between 1, S of shape (N,)
lens = torch.randint(0, S, (N,))
i = torch.tensor(range(0,7))
res = torch.zeros(N, H)
res = a[lens, i, :]
print(res)
А почему вы сделали объектив 1 из S + 1, а затем сделали lens[i]-1
? Я просто изменил его, чтобы объектив был 0 от S для удобства. Однако если вам нужно, чтобы объектив был 1 от S + 1, вы можете изменить
res = a[lens, i, :]
на
res = a[lens-1, i, :]