Индексирование элементов max в многомерном тензоре в PyTorch - PullRequest
0 голосов
/ 06 января 2019

Я пытаюсь индексировать максимальные элементы по последнему измерению в многомерном тензоре. Например, скажем, у меня есть тензор

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

Здесь idx хранит максимальные индексы, которые могут выглядеть примерно так:

>>>> A
tensor([[[ 1.0503,  0.4448,  1.8663],
     [ 0.8627,  0.0685,  1.4241]],

    [[ 1.2924,  0.2456,  0.1764],
     [ 1.3777,  0.9401,  1.4637]],

    [[ 0.5235,  0.4550,  0.2476],
     [ 0.7823,  0.3004,  0.7792]],

    [[ 1.9384,  0.3291,  0.7914],
     [ 0.5211,  0.1320,  0.6330]],

    [[ 0.3292,  0.9086,  0.0078],
     [ 1.3612,  0.0610,  0.4023]]])
>>>> idx
tensor([[ 2,  2],
    [ 0,  2],
    [ 0,  0],
    [ 0,  2],
    [ 1,  0]])

Я хочу иметь возможность получить доступ к этим индексам и назначить на них другой тензор. Это означает, что я хочу быть в состоянии сделать

B = torch.new_zeros(A.size())
B[idx] = A[idx]

где B везде 0, за исключением случаев, когда A максимально вдоль последнего измерения. То есть B должен хранить

>>>>B
tensor([[[ 0,  0,  1.8663],
     [ 0,  0,  1.4241]],

    [[ 1.2924,  0,  0],
     [ 0,  0,  1.4637]],

    [[ 0.5235,  0,  0],
     [ 0.7823,  0,  0]],

    [[ 1.9384,  0,  0],
     [ 0,  0,  0.6330]],

    [[ 0,  0.9086,  0],
     [ 1.3612,  0,  0]]])

Это оказывается намного сложнее, чем я ожидал, так как idx не индексирует массив A должным образом. До сих пор мне не удалось найти векторизованное решение для использования idx для индекса A.

Есть ли хороший векторизованный способ сделать это?

Ответы [ 2 ]

0 голосов
/ 06 января 2019

Вы можете использовать torch.meshgrid для создания кортежа индекса:

>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
>>> B = torch.zeros_like(A)
>>> B[index_tuple] = A[index_tuple]

Обратите внимание, что вы также можете имитировать meshgrid через (для конкретного случая 3D):

>>> index_tuple = (
...     torch.arange(A.size(0))[:, None],
...     torch.arange(A.size(1))[None, :],
...     idx
... )

Немного больше объяснений:
У нас будут индексы примерно так:

In [173]: idx 
Out[173]: 
tensor([[2, 1],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 2]])

Исходя из этого, мы хотим перейти к трем индексам (поскольку наш тензор 3D, нам нужно три числа для извлечения каждого элемента). По сути, мы хотим построить сетку в первых двух измерениях, как показано ниже. (И именно поэтому мы используем сетку).

In [174]: A[0, 0, 2], A[0, 1, 1]  
Out[174]: (tensor(0.6288), tensor(-0.3070))

In [175]: A[1, 0, 2], A[1, 1, 0]  
Out[175]: (tensor(1.7085), tensor(0.7818))

In [176]: A[2, 0, 2], A[2, 1, 1]  
Out[176]: (tensor(0.4823), tensor(1.1199))

In [177]: A[3, 0, 2], A[3, 1, 2]    
Out[177]: (tensor(1.6903), tensor(1.0800))

In [178]: A[4, 0, 2], A[4, 1, 2]          
Out[178]: (tensor(0.9138), tensor(0.1779))

В приведенных выше 5 строках первые два числа в индексах - это, в основном, сетка, которую мы строим с использованием meshgrid, а третье число происходит от idx.

т.е. первые два числа образуют сетку.

 (0, 0) (0, 1)
 (1, 0) (1, 1)
 (2, 0) (2, 1)
 (3, 0) (3, 1)
 (4, 0) (4, 1)
0 голосов
/ 06 января 2019

Гадкий взлом состоит в том, чтобы создать двоичную маску из idx и использовать ее для индексации массивов. Основной код выглядит так:

import torch
torch.manual_seed(0)

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
B = torch.zeros_like(A)
B[mask] = A[mask]
print(A)
print(B)

Хитрость в том, что torch.arange(A.size(2)) перечисляет возможные значения в idx, а mask отлично от нуля в местах, где они равны idx. Примечания:

  1. Если вы действительно отбрасываете первый вывод torch.max, вы можете использовать torch.argmax.
  2. Я предполагаю, что это минимальный пример более широкой проблемы, но учтите, что вы в настоящее время заново изобретаете torch.nn.functional.max_pool3d с ядром размером (1, 1, 3).
  3. Также следует помнить, что модификация тензоров на месте с маскированным присваиванием может вызвать проблемы с autograd, поэтому вы можете использовать torch.where, как показано здесь .

Я ожидаю, что кто-то придумает более чистое решение (избегая промежуточного выделения массива mask), вероятно, используя torch.index_select, но я не могу заставить его работать прямо сейчас.

...