У меня есть большой вектор, который я хотел бы обновить.Я обновлю его, добавив смещение для определенных элементов в векторе.Я указываю вектор индексов, которые я хочу обновить (вызываем вектор индекса ix
), и для каждого индекса я указываю значение, которое я хочу добавить к этому элементу (вызываем вектор значения vals
).Если все элементы вектора индекса уникальны, тогда достаточно следующего кода:
vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0,2], dtype=torch.long)
vals = torch.tensor([0.2, 0.5], dtype=torch.float)
vec[ix] += vals
Однако это не работает, если в ix
есть повторяющиеся индексы.Наивный подход для случая повторяющихся индексов заключается в следующем:
for i in range(len(ix)):
vec[ix[i]] += vals[i]
Но это плохо масштабируется - он очень медленный, когда ix
велико.Есть ли более быстрые способы сделать это?Если бы был быстрый способ суммировать все записи в vals
, которые имеют одинаковый индекс в ix
, то решение должно быть простым.
Обновление:
Я нашелодно решение, которое работает довольно хорошо, описано ниже.Я все еще хотел бы обратной связи для лучших решений.
# get unique indices
ix_unique = torch.unique(ix)
# for each unique index, get sum of all vals with that index
vals_unique = torch.stack([
torch.sum(torch.where(ix==i, vals, torch.zeros_like(vals)))
for i in ix_unique
])
# update vec
vec[ix_unique] += vals_unique