Индексация многомерного тензора с тензором в PyTorch - PullRequest
0 голосов
/ 30 августа 2018

У меня есть следующий код:

a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])

У меня есть многомерный индекс b, и я хочу использовать его для выбора отдельной ячейки в a. Если бы b не был тензорным, я мог бы сделать:

a[1,1,1,1]

, который возвращает правильную ячейку, но:

a[b]

Не работает, потому что просто выбирает a[1] четыре раза.

Как я могу это сделать? Благодаря

Ответы [ 2 ]

0 голосов
/ 30 августа 2018

Более элегантным (и более простым) решением может быть просто приведение b в виде кортежа:

a[tuple(b)]
Out[10]: tensor(5.)

Мне было любопытно посмотреть, как это работает с "обычным" numpy, и нашел соответствующую статью, объясняющую это довольно хорошо здесь .

0 голосов
/ 30 августа 2018

Вы можете разделить b на 4 с помощью chunk, а затем использовать чанкированный b для индексации нужного элемента:

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])

Что приятно, так это то, что он может быть легко обобщен для любого измерения a, вам просто нужно сделать количество патронов равным измерению a.

...