Есть ли в pytorch встроенный метод для извлечения строк с заданными индексами? - PullRequest
1 голос
/ 04 мая 2020

Предположим, у меня есть тензор факела

import torch
a = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])

и список

b = [0,2]

Есть ли встроенный метод для извлечения строк 0 и 2 и помещения их в новый тензор:

tensor([[1,2,3],
        [7,8,9]])

В частности, есть ли функция, которая выглядит следующим образом:

extract_rows(a,b) -> c

, где c содержит нужные строки. Конечно, это может быть сделано для l oop, но встроенный метод в целом быстрее.

Обратите внимание, что этот пример является только примером, в списке могут быть десятки индексов, и сотни строк в тензоре.

Ответы [ 2 ]

1 голос
/ 04 мая 2020

взгляните на встроенный в факел метод index_select () . Это было бы полезно для вас. или Вы можете сделать это, используя нарезку.

tensor = [[1,2,3],
            [4,5,6],
            [7,8,9]]

new_tensor = tensor[0::2]
print(new_tensor)

Выход:

[[1, 2, 3], [7, 8, 9]]
0 голосов
/ 04 мая 2020

Просто a[b] будет работать

import torch
a = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])
b = [0,2]
a[b]
tensor([[1, 2, 3],
        [7, 8, 9]])
...