Как динамически индексировать тензор в pytorch? - PullRequest
2 голосов
/ 05 апреля 2019

Например, я получил тензор:

tensor = torch.rand(12, 512, 768)

И я получил индексный список, скажем, что это:

[0,2,3,400,5,32,7,8,321,107,100,511]

Я хочу выбрать 1 элемент из 512 элементов в измерении 2, учитывая список индексов. И тогда размер тензора станет (12, 1, 768).

Есть ли способ сделать это?

Ответы [ 2 ]

3 голосов
/ 05 апреля 2019

Существует также способ использовать PyTorch и избежать цикла, используя индексирование и torch.split:

tensor = torch.rand(12, 512, 768)

# create tensor with idx
idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# convert list to tensor
idx_tensor = torch.tensor(idx_list) 

# indexing and splitting
list_of_tensors = tensor[:, idx_tensor, :].split(1, dim=1)

КогдаВы звоните tensor[:, idx_tensor, :], вы получите тензор формы:
(12, len_of_idx_list, 768).
Где второе измерение зависит от вашего числа индексов.

Использование torch.split этот тензор разбит на список изтензоры формы: (12, 1, 768).

Итак, наконец, list_of_tensors содержит тензоры формы:

[torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768])]
0 голосов
/ 05 апреля 2019

Да, вы можете напрямую нарезать его, используя индекс, а затем использовать torch.unsqueeze() до повысить 2D-тензор до 3D:

# inputs
In [6]: tensor = torch.rand(12, 512, 768)
In [7]: idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]

# slice using the index and then put a singleton dimension along axis 1
In [8]: for idx in idx_list:
   ...:     sampled_tensor = torch.unsqueeze(tensor[:, idx, :], 1)
   ...:     print(sampled_tensor.shape)
   ...:     
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])

В качестве альтернативы, если вам нужен еще более краткий код и вы не хотите использовать torch.unsqueeze(), используйте:

In [11]: for idx in idx_list:
    ...:     sampled_tensor = tensor[:, [idx], :]
    ...:     print(sampled_tensor.shape)
    ...:     
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])

Примечание: нет необходимости использовать цикл for, если вы хотите сделать это нарезку только для одного idx из idx_list

...