Вы можете использовать np.take_along_axis
.
Сначала давайте создадим некоторые данные (вы должны были предоставить воспроизводимый пример ):
>>> N, H, W, C = 10, 20, 30, 3
>>> arr = np.random.randn(N, H, W, C)
>>> indices = np.random.randint(0, N, size=(H, W))
Затем мы будем использовать np.take_along_axis
. Но для этого массив indices
должен иметь ту же форму, что и массив arr
. Итак, мы используем np.newaxis
для вставки оси там, где формы не совпадают.
>>> res = np.take_along_axis(arr, indices[np.newaxis, ..., np.newaxis], axis=0)
Он уже дает полезный вывод, но с одноэлементным размером на первой оси:
>>> res.shape
(1, 20, 30, 3)
Итак, мы можем сжать это:
>>> res = np.squeeze(res)
>>> res.shape
(20, 30, 3)
И в конечном итоге проверить, соответствуют ли данные нам:
>>> np.all(res[0, 0] == arr[indices[0, 0], 0, 0])
True
>>> np.all(res[5, 3] == arr[indices[5, 3], 5, 3])
True