Я пытаюсь индексировать максимальные элементы по последнему измерению в многомерном тензоре. Например, скажем, у меня есть тензор
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
Здесь idx хранит максимальные индексы, которые могут выглядеть примерно так:
>>>> A
tensor([[[ 1.0503, 0.4448, 1.8663],
[ 0.8627, 0.0685, 1.4241]],
[[ 1.2924, 0.2456, 0.1764],
[ 1.3777, 0.9401, 1.4637]],
[[ 0.5235, 0.4550, 0.2476],
[ 0.7823, 0.3004, 0.7792]],
[[ 1.9384, 0.3291, 0.7914],
[ 0.5211, 0.1320, 0.6330]],
[[ 0.3292, 0.9086, 0.0078],
[ 1.3612, 0.0610, 0.4023]]])
>>>> idx
tensor([[ 2, 2],
[ 0, 2],
[ 0, 0],
[ 0, 2],
[ 1, 0]])
Я хочу иметь возможность получить доступ к этим индексам и назначить на них другой тензор. Это означает, что я хочу быть в состоянии сделать
B = torch.new_zeros(A.size())
B[idx] = A[idx]
где B везде 0, за исключением случаев, когда A максимально вдоль последнего измерения. То есть B должен хранить
>>>>B
tensor([[[ 0, 0, 1.8663],
[ 0, 0, 1.4241]],
[[ 1.2924, 0, 0],
[ 0, 0, 1.4637]],
[[ 0.5235, 0, 0],
[ 0.7823, 0, 0]],
[[ 1.9384, 0, 0],
[ 0, 0, 0.6330]],
[[ 0, 0.9086, 0],
[ 1.3612, 0, 0]]])
Это оказывается намного сложнее, чем я ожидал, так как idx не индексирует массив A должным образом. До сих пор мне не удалось найти векторизованное решение для использования idx для индекса A.
Есть ли хороший векторизованный способ сделать это?