Вот тот, который хорошо масштабируется для общих ndarrays -
def maxindex(a, b, fillna=-1):
sidx = a.argsort(-1)
m = np.isin(sidx,b)
idx = m.shape[-1] - m[...,::-1].argmax(-1) - 1
out = np.take_along_axis(sidx,idx[...,None],axis=-1).squeeze()
return np.where(m.any(-1), out, fillna)
Образцы прогонов -
In [83]: a
Out[83]:
array([[ 1, 4, 6, 2, 5],
[ 3, 2, 7, 12, 1],
[ 8, 5, 3, 1, 4],
[ 6, 10, 2, 4, 9]])
In [84]: b
Out[84]: array([0, 1, 4])
In [85]: maxindex(a, b) # all rows
Out[85]: array([4, 0, 0, 1])
In [86]: maxindex(a[1], b) # second row
Out[86]: array([0])
3D-кейс -
In [105]: a
Out[105]:
array([[[ 1, 4, 6, 2, 5],
[ 3, 2, 7, 12, 1],
[ 8, 5, 3, 1, 4],
[ 6, 10, 2, 4, 9]],
[[ 1, 4, 6, 2, 5],
[ 3, 2, 7, 12, 1],
[ 8, 5, 3, 1, 4],
[ 6, 10, 2, 4, 9]]])
In [106]: maxindex(a, b)
Out[106]:
array([[4, 0, 0, 1],
[4, 0, 0, 1]])