Как проиндексировать трехмерный тензор с двумерным тензором в pytorch? - PullRequest
0 голосов
/ 04 апреля 2020
import torch
a = torch.rand(5,256,120)
min_values, indices = torch.min(a,dim=0)
aa = torch.zeros(256,120)
for i in range(256):
    for j in range(120):
        aa[i,j] = a[indices[i,j],i,j]

print((aa==min_values).sum()==256*120)

Я хочу знать, как избежать использования for-for l oop для получения значений aa? (Я хочу использовать индексы для выбора элементов в других трехмерных тензорах, поэтому я не могу напрямую использовать значения, возвращаемые по min) *

1 Ответ

1 голос
/ 04 апреля 2020

Вы можете использовать torch.gather

aa = torch.gather(a, 0, indices.unsqueeze(0))

, как описано здесь: Нарезка 4D-тензора с 3D-индексом тензора в PyTorch

...