Вы получаете разные результаты, потому что именно так индексация реализована в Pytorch. Если вы передадите массив в качестве индекса, он будет «распакован». Например:
indices = torch.tensor([[0, 1], [0, 2], [1, 0]])
mask = torch.arange(1,28).reshape(3,3,3)
# tensor([[[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9]],
# [[10, 11, 12],
# [13, 14, 15],
# [16, 17, 18]],
# [[19, 20, 21],
# [22, 23, 24],
# [25, 26, 27]]])
mask[indices.numpy()]
эквивалентно mask[[0, 1], [0, 2], [1, 0]]
, т.е. элементы i-й строки indices.numpy()
используются для выбора элементов mask
вдоль i-й оси , Таким образом, он возвращает tensor([mask[0,0,1], mask[1,2,0]])
, то есть tensor([2, 16])
.
С другой стороны, при передаче тензора в качестве индекса (я не знаю точную причину такого различия между массивами и тензорами для индексации), он не "распаковывается" как массив, и элементы i-й строки тензора indices
используются для выбора элементов mask
вдоль оси-0. То есть mask[indices]
эквивалентен mask[[[0, 1], [0, 2], [1, 0]], :, :]
>>> mask[ind]
tensor([[[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 11, 12],
[13, 14, 15],
[16, 17, 18]]],
[[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[19, 20, 21],
[22, 23, 24],
[25, 26, 27]]],
[[[10, 11, 12],
[13, 14, 15],
[16, 17, 18]],
[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]]]])
, который в основном tensor(mask[[0,1], :, :], mask[[0,2],: ,:], mask[[1,0], :, :])
и имеет форму indices.shape + mask[0,:,:].shape == (3,2,3,3)
. Таким образом, целые «листы» отбираются и укладываются в новые измерения. Обратите внимание, что это не новый тензор, а особое представление mask
. Следовательно, если вы присваиваете mask[indices] = 1
с этим конкретным indices
, то все элементы mask
станут равны 1.