Многомерный numpy.expand_dims в numba - PullRequest
0 голосов
/ 05 ноября 2019

Из-за отсутствия np.einsum и np.newaxis в numba, я хотел бы получить удобный способ выполнить np.expand_dims для нескольких осей, в то же время имея возможность numba.njit над функцией. Два решения, которые я не могу получить njit 'ed:

def expand_dims(arr, axes):
    for ax in axes:
        arr = np.expand_dims(arr, ax)
    return arr

def expand_dims2(arr, axes):
    shape_list = list(arr.shape)
    for ax in axes:
        shape_list.insert(ax, 1)
    return arr.reshape(tuple(shape_list))

с axes, являющимся итерацией осевых индексов, которые должны быть созданы.

Есть ли хороший способ сделать это, имея возможность njit функции? В противном случае код очень быстро запутывается.

...