Пакетное индексирование Pytorch - PullRequest
1 голос
/ 08 мая 2020

Итак, выходные данные моей сети выглядят следующим образом:

output = tensor([[[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.0315, -0.1837],
     [ 0.0318, -0.1828],
     [ 0.0322, -0.1822],
     [ 0.0324, -0.1819],
     [ 0.0327, -0.1817],
     [ 0.0328, -0.1815],
     [ 0.0330, -0.1815],
     [ 0.0331, -0.1814],
     [ 0.0332, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]]])

Это форма [8, 24, 2] Теперь 8 - это размер моего пакета. И я хотел бы получить точку данных из каждого пакета в следующих местах:

index = tensor([24, 10,  3,  3,  1,  1,  1,  0])

Итак, 24-е значение из первого пакета, 10-е значение из второго пакета и т. Д.

Теперь у меня проблемы с синтаксисом. Я пробовал

torch.gather(output, 0, index)

Но мне все время говорят, что мои размеры не совпадают. И пробуя

output[ : ,index]

Просто получаю значения по всем индексам для каждой партии. Каким здесь будет правильный синтаксис, чтобы получить эти значения?

Ответы [ 2 ]

0 голосов
/ 08 мая 2020

Сначала небольшое примечание: для формы вывода [8, 24, 2] наибольший индекс второго измерения может быть 23, поэтому я изменяю ваши индексы на

index = torch.tensor([23, 10,  3,  3,  1,  1,  1,  0])
output = torch.randn((8,24,2)) # Toy data to represent your output

Самое простое решение - использовать для l oop

data_pts = torch.zeros((8,2)) # Tensor to store desired values

for i,j in enumerate(index):
    data_pts[i, :] = output[i, j, :]

Однако, если вы хотите векторизовать индексацию, вам просто нужны индексы для всех измерений. Например,

data_pts_vectorized = output[range(8), index, :] 

Поскольку ваш вектор индекса в порядке, вы можете сгенерировать индекс первого измерения с помощью range.

Вы можете подтвердить, что оба подхода дают одинаковые результаты

assert(torch.all(data_pts == data_pts_vectorized))
0 голосов
/ 08 мая 2020

Чтобы выбрать только один элемент в пакете, вам необходимо перечислить индексы пакета, что можно легко сделать с помощью torch.arange.

output[torch.arange(output.size(0)), index]

Это по существу создает кортежи между перечисленными тензор и тензор index для доступа к данным, что приводит к индексации output[0, 24], output[1, 10] et c.

...