Итак, выходные данные моей сети выглядят следующим образом:
output = tensor([[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.0315, -0.1837],
[ 0.0318, -0.1828],
[ 0.0322, -0.1822],
[ 0.0324, -0.1819],
[ 0.0327, -0.1817],
[ 0.0328, -0.1815],
[ 0.0330, -0.1815],
[ 0.0331, -0.1814],
[ 0.0332, -0.1814],
[ 0.0333, -0.1814],
[ 0.0333, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]]])
Это форма [8, 24, 2]
Теперь 8 - это размер моего пакета. И я хотел бы получить точку данных из каждого пакета в следующих местах:
index = tensor([24, 10, 3, 3, 1, 1, 1, 0])
Итак, 24-е значение из первого пакета, 10-е значение из второго пакета и т. Д.
Теперь у меня проблемы с синтаксисом. Я пробовал
torch.gather(output, 0, index)
Но мне все время говорят, что мои размеры не совпадают. И пробуя
output[ : ,index]
Просто получаю значения по всем индексам для каждой партии. Каким здесь будет правильный синтаксис, чтобы получить эти значения?