У меня есть массив idx
, например [0, 1, 0, 2, 3, 1]
, и еще один массив 2d data
, например, следующий:
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[9, 10, 11],
[12, 13, 14],
[15, 16, 17]]
Я хочу, чтобы мой вывод был 4x3
, где 4 - максимумidx
и 3 - это размер элемента (data.shape[1]
), а в выходных данных каждый элемент представляет собой сумму элементов с соответствующим индексом в idx
.Тогда вывод в этом примере будет:
[[6, 8, 10],
[18, 20, 22],
[9, 10, 11],
[12, 13, 14]]
Я могу сделать это с помощью итерации по range(3)
и создания маски для данных и суммирования их, но она не дифференцируема (я полагаю).Есть ли какие-либо функции в Pytorch для этой цели?что-то вроде scatter()
.
Обновление: Кажется, я ищу что-то с именем scatter sum , которое реализовано в этом хранилище.