numba и numpy.expand_dims - PullRequest
       7

numba и numpy.expand_dims

0 голосов
/ 05 марта 2019

Я переписываю некоторые из своих функций, чтобы они подходили для Numba. Теперь у меня есть функция, которую я вызываю несколько раз в моем скрипте с входными массивами разных измерений.

def FormHistMatrix2(x,Whc,Lm):
    if x.ndim == 1:
       x = np.expand_dims(x,axis=1)
    [N,Ncells] = x.shape

Это начало моей функции, и Нумба выдает следующую ошибку:

TypingError: Cannot unify array(float64, 2d, A) and array(float64, 3d, A) for 'x', defined at C:/Users/DNP_Student_3/Documents/Python Scripts/GCFuncsTests.py (332)

В этом случае 'x' - это двумерный массив, но в других случаях это может быть одномерный массив. Так разве Нумбе не нравится цикл if? Или что здесь происходит?

Ответы [ 2 ]

1 голос
/ 05 марта 2019

В Numba, в отличие от стандартного python, переменная не может изменить свой тип во время выполнения функции. Вы должны иметь возможность присвоить результат вызова np.expand_dims другой переменной, и он будет работать. Это нормально, если иногда x равен 1d, а иногда - 2d, если есть согласованность типов всех переменных при выполнении функции.

0 голосов
/ 06 марта 2019

То, что сказал ДжошАдель, в целом верно, но проблема в этом случае в том, что вам нужна другая реализация / специализация вашей функции в зависимости от типа ввода.

Numba имеет @generated_jit -декоратор для этого случая.

В вашем случае вам нужно написать специальную функцию expand-dims, которая зависит от размеров входного массива:

import numba as nb
@nb.generated_jit(nopython=True)
def nb_expander(x):
    if x.ndim == 1:
        return lambda x: np.expand_dims(x, axis=1)
    else:
        return lambda x: x

Эту функцию необходимо вызывать из другой функции:

@nb.njit
def FormHistMatrix2(x, Whc, Lm):
    x = nb_expander(x)
    [N, Ncells] = x.shape

Теперь это будет работать для x размеров 1 и 2. Для x.ndim==3 вам также необходимо реализовать аналогичный метод для фигуры.

...