Как правильно ускорить с помощью Numba? - PullRequest
0 голосов
/ 23 апреля 2020

В настоящее время я работаю над ускорением для моих python функций.

def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)


def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

def distance(u,v,lon1,lat1):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    return dlon, dlat

Как видите, это простой код, основанный на numpy. Я прочитал большую часть статей о inte rnet, которые, как они сказали, просто помещают @ numba.jit в качестве декоратора перед функцией, и затем я могу использовать Numba, чтобы ускорить мой код.

Вот тест, который я сделал.

u = np.random.randn(10000)
v = np.random.randn(10000)
lon1 = np.random.uniform(-99,-96,10000)
lat1 = np.random.uniform( 23, 25,10000)
print(u)
%%timeit
for i in range(10000):
    distance(u,v,lon1,lat1)

5,61 с ± 58,7 мс на л oop (среднее ± стандартное отклонение из 7 прогонов, 1 л oop каждый)

Добавить декоратор Numba

@numba.njit()
def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

@numba.njit()
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

@numba.njit()
def distance(u, v, lon1, lat1, R=6.371*1e6):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
    a,b = distance(u,v,lon1,lat1)

7,76 с ± 64,9 мс на л oop (среднее ± стандартное отклонение из 7 прогонов, 1 л oop каждый)

Как вы можете видеть выше , скорость вычислений в моем случае Numba ниже, чем в моем чистом python случае. Может ли кто-нибудь помочь мне решить эту проблему?

ps: версия numba
llvmlite 0.32.0rc1
numba 0.49.0rc2

------ Вычислительный тест по поводу ответа макроэкономиста. ------

Согласно его ответу, даже Numba теперь достаточно умен, если мы хотим, чтобы код был украшен Numba, лучше использовать простой "Fortran" / "C "тип стилей. Ниже представлено сравнение времени вычислений между различными методами, о которых я думал.

def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

def distance(u,v,lon1,lat1):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    return dlon, dlat
%%timeit
for i in range(10000):
    distance(u,v,lon1,lat1)

54 с ± 485 мс на л oop (среднее ± стандартное отклонение для 7 прогонов, 1 л oop каждый)

@numba.jit(nogil=True)
def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

@numba.jit(nogil=True)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

def distance(u, v, lon1, lat1, R=6.371*1e6):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    return d_lon(lat1,lat2,dlon), d_lat(dlat)
%%timeit
for i in range(10000):
    a,b = distance(u,v,lon1,lat1)

1 мин 21 с ± 815 мс на л oop (среднее ± стандартное отклонение из 7 прогонов, 1 л oop каждый)

def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

@numba.njit(nogil=True)
def distance(u, v, lon1, lat1, R=6.371*1e6):
    def d_lat(dlat,R=6.371*1e6):
        return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)             
    def d_lon(lat1,lat2,dlon,R=6.371*1e6):
        return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                               np.cos(np.deg2rad(lat2)) *
                               np.sin(np.deg2rad(dlon)/2)**2)
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlat = d_lat(lat2 - lat1)
    dlon = d_lon(lat1,lat2,lon2 - lon1)
    return dlon, dlat
%%timeit
for i in range(10000):
    a,b = distance(u,v,lon1,lat1)

1 мин 2 с ± 239 мс на л oop (среднее ± стандартное отклонение из 7 прогонов, 1 л oop каждый)

@numba.njit() 
def d_lat(dlat,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2) 

@numba.njit() 
def d_lon(lat1,lat2,dlon,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *  
                           np.cos(np.deg2rad(lat2)) * 
                           np.sin(np.deg2rad(dlon)/2)**2) 

@numba.njit() 
def distance(u, v, lon1, lat1): 
    lon2 = np.empty_like(lon1) 
    lat2 = np.empty_like(lat1) 
    dlon = np.empty_like(lon1) 
    dlat = np.empty_like(lat1) 

    for i in range(len(v)): 
        vi = v[i] 
        if vi > 0: 
            lat2[i] = lat1[i]+1 
            dlat[i] = 1 
        elif vi < 0: 
            lat2[i] = lat1[i]-1 
            dlat[i] = -1 
        else: 
            lat2[i] = lat1[i] 
            dlat[i] = 0 

    for i in range(len(u)): 
        ui = u[i] 
        if ui > 0:  
            lon2[i] = lon1[i]+1 
            dlon[i] = 1 
        elif ui < 0: 
            lon2[i] = lon1[i]-1 
            dlon[i] = -1 
        else: 
            lon2[i] = lon1[i] 
            dlon[i] = 0 

    return d_lon(lat1,lat2,dlon), d_lat(dlat) 
%%timeit
for i in range(10000):
    distance(u,v,lon1,lat1)

35,9 с ± 537 мс на л oop (среднее ± стандартное отклонение от 7 пробежек, 1 л oop каждый)

1 Ответ

1 голос
/ 24 апреля 2020

Есть несколько проблем, которые выскакивают.

Во-первых, ваши вычисления в функции distance излишне сложны и написаны в стиле (с большим количеством причудливых индексаций, например, lat2[v>0]), которые могут не быть идеальным для компилятора Numba. Хотя Numba становится умнее, я нахожу, что все еще высока отдача от написания кода простым, ориентированным на l oop способом.

Во-вторых, Numba может быть немного замедлена необязательными аргументами. Я обнаружил, что это верно в первую очередь для необязательной R в вашей функции distance.

Исправление этих двух проблем - в частности, замена вашего векторизованного кода на более простые циклы, которые минимизируют операции - мы получаем код form

@numba.njit() 
def d_lat(dlat,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2) 

@numba.njit() 
def d_lon(lat1,lat2,dlon,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *  
                           np.cos(np.deg2rad(lat2)) * 
                           np.sin(np.deg2rad(dlon)/2)**2) 

@numba.njit() 
def distance(u, v, lon1, lat1): 
    lon2 = np.empty_like(lon1) 
    lat2 = np.empty_like(lat1) 
    dlon = np.empty_like(lon1) 
    dlat = np.empty_like(lat1) 

    for i in range(len(v)): 
        vi = v[i] 
        if vi > 0: 
            lat2[i] = lat1[i]+1 
            dlat[i] = 1 
        elif vi < 0: 
            lat2[i] = lat1[i]-1 
            dlat[i] = -1 
        else: 
            lat2[i] = lat1[i] 
            dlat[i] = 0 

    for i in range(len(u)): 
        ui = u[i] 
        if ui > 0:  
            lon2[i] = lon1[i]+1 
            dlon[i] = 1 
        elif ui < 0: 
            lon2[i] = lon1[i]-1 
            dlon[i] = -1 
        else: 
            lon2[i] = lon1[i] 
            dlon[i] = 0 

    return d_lon(lat1,lat2,dlon), d_lat(dlat) 

В моей (более медленной) системе это уменьшает время после начальной стоимости компиляции с примерно 7 секунд до примерно 4 секунд. На данный момент, я полагаю, что в стоимости преобладают исходные затраты на все функции np.sin, np.cos, np.exp и др. c.

...