Вот как вы можете это сделать для этого примера:
import numpy as np
input_ = np.array([[0, 1],
[2, 3]])
subs = np.array([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15]]])
res = subs[input_].transpose((0, 2, 1, 3)).reshape((4, 4))
print(res)
# [[ 0 1 4 5]
# [ 2 3 6 7]
# [ 8 9 12 13]
# [10 11 14 15]]
EDIT:
Более общее решение, поддерживающее большее количество измерений, входных данных и замен с различным количеством измерений:
import numpy as np
def expand_from(input_, subs):
input_= np.asarray(input_)
subs = np.asarray(subs)
# Take from subs according to input
res = subs[input_]
# Input dimensions
in_dims = input_.ndim
# One dimension of subs is for indexing
s_dims = subs.ndim - 1
# Dimensions that correspond to each other on output
num_matched = min(in_dims, s_dims)
matched_dims = [(i, in_dims + i) for i in range(num_matched)]
# Additional dimensions if there are any
if in_dims > s_dims:
extra_dims = list(range(num_matched, in_dims))
else:
extra_dims = list(range(2 * num_matched, in_dims + s_dims))
# Dimensions order permutation
dims_reorder = [d for m in matched_dims for d in m] + extra_dims
# Output final shape
res_shape = ([res.shape[d1] * res.shape[d2] for d1, d2 in matched_dims] +
[res.shape[d] for d in extra_dims])
return res.transpose(dims_reorder).reshape(res_shape)
input_ = np.array([[0, 1],
[2, 3]])
subs = np.array([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15]]])
output = expand_from(input_, subs)
print(output)
# [[ 0 1 4 5]
# [ 2 3 6 7]
# [ 8 9 12 13]
# [10 11 14 15]]