Да, вы можете напрямую нарезать его, используя индекс, а затем использовать torch.unsqueeze()
до повысить 2D-тензор до 3D:
# inputs
In [6]: tensor = torch.rand(12, 512, 768)
In [7]: idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# slice using the index and then put a singleton dimension along axis 1
In [8]: for idx in idx_list:
...: sampled_tensor = torch.unsqueeze(tensor[:, idx, :], 1)
...: print(sampled_tensor.shape)
...:
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
В качестве альтернативы, если вам нужен еще более краткий код и вы не хотите использовать torch.unsqueeze()
, используйте:
In [11]: for idx in idx_list:
...: sampled_tensor = tensor[:, [idx], :]
...: print(sampled_tensor.shape)
...:
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
Примечание: нет необходимости использовать цикл for
, если вы хотите сделать это нарезку только для одного idx
из idx_list