numpy
имеет take_along_axis
, что делает то, что вы описываете, плюс оно также позволяет вам выбрать ось.
Пример:
>>> a = np.arange(24).reshape(2,3,4)
>>> i = np.random.randint(0,4,(2,3,5))
>>> i
array([[[3, 3, 0, 1, 3],
[3, 1, 0, 3, 3],
[3, 2, 0, 3, 3]],
[[2, 3, 0, 0, 0],
[1, 1, 3, 1, 2],
[1, 3, 0, 0, 2]]])
>>> np.take_along_axis(a, i, -1)
array([[[ 3, 3, 0, 1, 3],
[ 7, 5, 4, 7, 7],
[11, 10, 8, 11, 11]],
[[14, 15, 12, 12, 12],
[17, 17, 19, 17, 18],
[21, 23, 20, 20, 22]]])