Нумба жалуется на печатать - но все типы предоставляются - PullRequest
0 голосов
/ 13 февраля 2020

У меня проблема с набором текста Numba - я прочитал руководство, но в итоге столкнулся с кирпичной стеной.

Данная функция является частью более крупного проекта, хотя она должна выполняться быстро - Python списки исключены, поэтому я решил попробовать Numba. К сожалению, функция не работает в режиме nopython = True, несмотря на тот факт, что, согласно моему пониманию, предоставляются все типы.

Код выглядит так:

from Numba import jit, njit, uint8, int64, typeof

@jit(uint8[:,:,:](int64))
def findWhite(cropped):
    h1 = int64(0)
    for i in cropped:
        for j in i:
            if np.sum(j) == 765:
                h1 = h1 + int64(1)
            else:
                pass
    return h1

также, отдельно:

print(typeof(cropped))
array(uint8, 3d, C)
print(typeof(h1))
int64

В этом случае 'обрезается' - это большая матрица uint8 3D C (понимание файла RGB TIFF - PIL.Image). Может кто-нибудь объяснить новичку Numba ie что я делаю не так?

1 Ответ

1 голос
/ 14 февраля 2020

Рассматривали ли вы использовать Numpy? Часто это хорошее промежуточное звено между Python списками и Numba, что-то вроде:

h1 = (cropped.sum(axis=-1) == 765).sum()

или

h1 = (cropped == 255).all(axis=-1).sum()

Пример кода, который вы предоставляете, не является действительным Numba. Ваша подпись также неверна, поскольку входные данные являются трехмерным массивом, а выходные - целым числом, вероятно, оно должно быть следующим:

@njit(int64(uint8[:,:,:]))

Циклирование массива, как и вы, не является допустимым кодом. Точный перевод вашего кода будет выглядеть примерно так:

@njit(int64(uint8[:,:,:]))
def findWhite(cropped):

    h1 = int64(0)    
    ys, xs, n_bands = cropped.shape

    for i in range(ys):
        for j in range(xs):
            if cropped[i, j, :].sum() == 765:
                h1 += 1

    return h1

Но это не очень быстро и не побеждает Numpy на моей машине. С Numba можно явно l oop для каждого элемента в массиве, это уже намного быстрее:

@njit(int64(uint8[:,:,:]))
def findWhite_numba(cropped):

    h1 = int64(0)    
    ys, xs, zs = cropped.shape

    for i in range(ys):
        for j in range(xs):

            incr = 1
            for k in range(zs):

                if cropped[i, j, k] != 255:
                    incr = 0
                    break

            h1 += incr

    return h1

Для массива 5000x5000x3 это результат для меня:

Numpy (h1 = (cropped == 255).all(axis=-1).sum()):

427 ms ± 6.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

findWhite:

612 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

findWhite_numba:

31 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Преимущество метода Numpy заключается в том, что оно обобщает на любое количество измерений.

...