Из-за отсутствия np.einsum
и np.newaxis
в numba, я хотел бы получить удобный способ выполнить np.expand_dims
для нескольких осей, в то же время имея возможность numba.njit
над функцией. Два решения, которые я не могу получить njit
'ed:
def expand_dims(arr, axes):
for ax in axes:
arr = np.expand_dims(arr, ax)
return arr
def expand_dims2(arr, axes):
shape_list = list(arr.shape)
for ax in axes:
shape_list.insert(ax, 1)
return arr.reshape(tuple(shape_list))
с axes
, являющимся итерацией осевых индексов, которые должны быть созданы.
Есть ли хороший способ сделать это, имея возможность njit
функции? В противном случае код очень быстро запутывается.