Почему numpy.where гораздо быстрее, чем альтернативы - PullRequest
2 голосов
/ 12 марта 2019

я пытаюсь ускорить следующий код:

import time
import numpy as np
np.random.seed(10)
b=np.random.rand(10000,1000)
def f(a=1):
    tott=0
    for _ in range(a):
        q=np.array(b)
        t1 = time.time()
        for i in range(len(q)):
            for j in range(len(q[0])):
                if q[i][j]>0.5:
                    q[i][j]=1
                else:
                    q[i][j]=-1
        t2=time.time()
        tott+=t2-t1
    print(tott/a)

Как видите, в основном func - это итерация в двойном цикле. Итак, я попытался использовать np.nditer, np.vectorize и map вместо него. Если дает некоторое ускорение (например, 4-5 раз, кроме np.nditer), но! с np.where(q>0.5,1,-1) ускорение почти в 100 раз. Как я могу перебирать массивы так быстро, как это делает np.where? И почему это намного быстрее?

Ответы [ 2 ]

5 голосов
/ 12 марта 2019

Это потому, что ядро ​​numpy реализовано на C. Вы в основном сравниваете скорость C с Python.

Если вы хотите использовать преимущество numpy по скорости, вы должны сделать как можно меньше вызовов в своем коде Python. Если вы используете цикл Python, вы уже проиграли, даже если вы используете только функции numpy в этом цикле. Используйте высокоуровневые функции, предоставляемые numpy (поэтому они поставляют так много специальных функций). Внутренне, он будет использовать гораздо более эффективный (C-) цикл

Вы можете самостоятельно реализовать функцию в C (с циклами) и вызывать ее из Python. Это должно дать сопоставимые скорости.

4 голосов
/ 12 марта 2019

Чтобы ответить на этот вопрос, вы можете получить ту же скорость (ускорение в 100 раз), используя библиотеку numba:

from numba import njit

def f(b):
    q = np.zeros_like(b)

    for i in range(b.shape[0]):
        for j in range(b.shape[1]):
            if q[i][j] > 0.5:
                q[i][j] = 1
            else:
                q[i][j] = -1

    return q

@njit
def f_jit(b):
    q = np.zeros_like(b)

    for i in range(b.shape[0]):
        for j in range(b.shape[1]):
            if q[i][j] > 0.5:
                q[i][j] = 1
            else:
                q[i][j] = -1

    return q

Сравните скорость:

Простой Python

%timeit f(b)
592 ms ± 5.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Numba (как раз вовремя скомпилировано с использованием скорости LLVM ~ C)

%timeit f_jit(b)
5.97 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...