Вы можете использовать несколько многомерных индексов, но они должны быть одинакового размера или транслируемыми. Например,
# create a (11, 1) range array that broadcasts with indices which is (11, 16)
indices0 = np.expand_dims(np.arange(indices.shape[0]), 1)
A = B[indices0, indices, :]
Поскольку вещание может сбивать с толку, я попытаюсь объяснить это немного. По сути, вы хотите, чтобы indices0
и indices
имели одинаковый размер и представляли собой пары индексов B. Первый индекс будет сохранен в indices0
, а второй будет сохранен в indices
в соответствующих местах. Вещание неявно повторяет столбцы indices0
, чтобы придать ему ту же форму, что и indices
, и часто может быть быстрее, чем создание полноразмерного indices0
.
. В случае, если это поможет, приведем еще несколько подробных примеров, демонстрирующих почему это работает:
import torch
import numpy as np
B = torch.randn(11, 9, 64)
indices = np.random.randint(0,9,(11,16))
# constructing indices0 more verbosely (and slower) for demonstration purposes
a0, a1 = indices.shape
a2 = B.shape[2]
# construct a complete indices0 the slow way, the same size as indices
indices0 = np.empty((a0, a1), dtype=np.int32)
for i in range(a0):
for j in range(a1):
indices0[i,j] = i
# version 1 (nothing complicated happening here but very slow)
A1 = torch.empty(a0, a1, a2, dtype=B.dtype)
for i in range(a0):
for j in range(a1):
A1[i,j,:] = B[indices0[i,j], indices[i,j], :]
# version 2 (using advanced indexing without broadcasting)
A2 = B[indices0, indices, :]
# version 3 (with broadcasting)
# remove repeated columns leaving indices0 as (11, 1) the same state as above
indices0 = indices0[:, :1]
# broadcasting implicitly repeats columns of indices0 to match indices
A3 = B[indices0, indices, :]
# version 4 (your method)
A4 = torch.empty(a0, a1, a2, dtype=B.dtype)
for i in range(a0):
A4[i,:,:] = B[i,indices[i],:]
# compare everything
error = torch.sum(torch.abs(A1 - A2)).item() + \
torch.sum(torch.abs(A2 - A3)).item() + \
torch.sum(torch.abs(A3 - A4)).item()
print('Error:', error)
, который печатает
Error: 0.0
, демонстрируя, что все эти методы эквивалентны.
Кроме того, если вы хотите остаться в Платформа PyTorch и indices
были torch.LongTensor
вместо numpy.ndarray
, тогда вы могли бы использовать
indices0 = torch.arange(indices.shape[0]).unsqueeze(1)
A = B[indices0, indices, :]