Факел сумма каждой строки, исключая индекс - PullRequest
0 голосов
/ 06 апреля 2020

Учитывая Tensor A формы (N, C) и индекс Tensor Idx формы (N,), я хотел бы суммировать все элементы каждой строки в A, исключая соответствующий индекс столбца в I. Например:

A = torch.tensor([[1,2,3],
                  [4,5,6]])

Idx = torch.tensor([0,2])

#result:
torch.tensor([[5],
              [9]])

Решение с использованием циклов известно.

1 Ответ

1 голос
/ 06 апреля 2020

Вы можете установить исключенные элементы в ноль:

A[range(A.shape[0]),Idx] = 0

и суммировать тензор по строкам:

b = A.sum(dim = 1,keepdim = True ) # b = torch.tensor([[5],  [9]])
...