Учитывая Tensor A формы (N, C) и индекс Tensor Idx формы (N,), я хотел бы суммировать все элементы каждой строки в A, исключая соответствующий индекс столбца в I. Например:
A = torch.tensor([[1,2,3],
[4,5,6]])
Idx = torch.tensor([0,2])
#result:
torch.tensor([[5],
[9]])
Решение с использованием циклов известно.