Предположим, у меня есть список индексов и я хочу изменить существующий массив с этим списком.В настоящее время единственный способ сделать это - использовать цикл for следующим образом.Просто интересно, есть ли более быстрый / эффективный способ.
torch.manual_seed(0)
a = torch.randn(5,3)
idx = torch.Tensor([[1,2], [3,2]], dtype=torch.long)
for i,j in idx:
a[i,j] = 1
Я изначально предполагал, что gather
или index_select
пойдут каким-то образом в ответе на этот вопрос, но, глядя на документацию это не похоже на ответ.
В моем конкретном случае a - это 5-мерный вектор, а idx - это вектор Nx5.Таким образом, результат (после подписки с чем-то вроде a[idx]
), я ожидаю, будет иметь форму (N,)
вектора.
Ответ
Благодаря @shai ниже, ответ, который я искалбыло: a[idx.t().chunk(chunks=2,dim=0)]
.Взято из этого ТАКого ответа .