Условно дифференцируемая сумма, основанная на индексе - PullRequest
1 голос
/ 19 сентября 2019

У меня есть массив idx, например [0, 1, 0, 2, 3, 1], и еще один массив 2d data, например, следующий:

[[0,  1,  2],
 [3,  4,  5],
 [6,  7,  8],
 [9,  10, 11],
 [12, 13, 14],
 [15, 16, 17]]

Я хочу, чтобы мой вывод был 4x3, где 4 - максимумidx и 3 - это размер элемента (data.shape[1]), а в выходных данных каждый элемент представляет собой сумму элементов с соответствующим индексом в idx.Тогда вывод в этом примере будет:

[[6,  8,  10],
 [18, 20, 22],
 [9,  10, 11],
 [12, 13, 14]]

Я могу сделать это с помощью итерации по range(3) и создания маски для данных и суммирования их, но она не дифференцируема (я полагаю).Есть ли какие-либо функции в Pytorch для этой цели?что-то вроде scatter().

Обновление: Кажется, я ищу что-то с именем scatter sum , которое реализовано в этом хранилище.

1 Ответ

1 голос
/ 19 сентября 2019

Вы ищете index_add_:

import torch

x = torch.tensor([[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.],
        [ 9., 10., 11.],
        [12., 13., 14.],
        [15., 16., 17.]], dtype=torch.float)
idx = torch.tensor([0, 1, 0, 2, 3, 1], dtype=torch.long)  # note the dtype here, must be "long"
# init the sums to zero
y = torch.zeros((idx.max()+1, x.shape[1]), dtype=x.dtype)

# do the magic
y.index_add_(0, idx, x)

Дает желаемый результат

tensor([[ 6.,  8., 10.],
        [18., 20., 22.],
        [ 9., 10., 11.],
        [12., 13., 14.]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...