Оптимизируйте код для пошаговой функции, используя только NumPy - PullRequest
2 голосов
/ 17 февраля 2020

Я пытаюсь оптимизировать функцию 'pw' в следующем коде, используя только NumPy функции (или, возможно, списки).

from time import time
import numpy as np

def pw(x, udata):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:      x   --  interval which the step function is defined over
              data  --  an ordered set of data (without repetitions)
    OUTPUT: pw_func --  an array of size x.shape[0]
    """
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func


N = 50000
x = np.linspace(0,10,N)
data = [1,3,4,5,5,7]
udata = np.unique(data)

ti = time()
pw(x,udata)
tf = time()
print(tf - ti)

import cProfile
cProfile.run('pw(x,udata)')

Файл cProfile.run сообщает мне, что большинство накладные расходы приходят из np.where (около 1 мс), но я хотел бы создать более быстрый код, если это возможно. Кажется, что выполнение операций по строкам по сравнению со столбцами имеет некоторое значение, если я не ошибаюсь, но я думаю, что я это учел. Я знаю, что иногда списки могут быть быстрее, но я не мог найти более быстрый способ, чем то, что я делаю, используя их.

Похоже, что поиск с сортировкой дает лучшую производительность, но на моем компьютере все еще остается 1 мс:

(modified)
def pw(xx, uu):
    """
    Creates the step function
                 | 1,  if d0 <= x < d1
                 | 2,  if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:      x   --  interval which the step function is defined over
              data  --  an ordered set of data (without repetitions)
    OUTPUT: pw_func --  an array of size x.shape[0]
    """
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds != uu.shape[0]]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func

Этот ответ , использующий кусочно, кажется довольно близким, но это на скаляре x0 и x1. Как бы я сделал это на массивах? И будет ли это более эффективным?

Понятно, что x может быть довольно большим, но я пытаюсь пройти его через стресс-тест.

Я все еще учусь, хотя некоторые подсказки или уловки, которые может помочь мне было бы здорово.

РЕДАКТИРОВАТЬ

Кажется, во второй функции есть ошибка, так как результирующий массив из второй функции не соответствует первый (который, я уверен, что он работает):

N1 = pw1(x,udata.reshape(udata.shape[0],1)).shape[0]
N2 = np.sum(pw1(x,udata.reshape(udata.shape[0],1)) == pw2(x,udata))
print(N1 - N2)

дает

15000

точек данных, которые не одинаковы. Так что, похоже, я не знаю, как использовать 'searchsorted'.

EDIT 2

На самом деле я это исправил:

pw_func = vals[inds[inds != uu.shape[0]]]

было изменено на

pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]

, поэтому по крайней мере полученные массивы совпадают. Но все еще остается вопрос о том, есть ли более эффективный способ сделать это.

РЕДАКТИРОВАТЬ 3

Спасибо Тин Лай за указание на ошибку. Этот должен работать

pw_func = vals[inds[(inds != uu.shape[0])*(inds != 0)]-1]

Возможно, более читабельный способ представить его будет

non_endpts = (inds != uu.shape[0])*(inds != 0) # only consider the points in between the min/max data values
shift_inds = inds[non_endpts]-1       # searchsorted side='right' includes the left end point and not right end point so a shift is needed
pw_func = vals[shift_inds]

Я думаю, что я заблудился во всех этих скобках! Я думаю, что это важная читаемость.

1 Ответ

2 голосов
/ 17 февраля 2020

Очень абстрактная, но интересная проблема! Спасибо за то, что развлекли меня, я повеселился:)

ps Я не уверен насчет вашего pw2 Я не смог получить его на выходе так же, как pw1.

Для ссылка на оригинал pw s:

def pw1(x, udata):
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func

def pw2(xx, uu):
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))

    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func

Моя первая попытка была использовать много операций вещания с numpy:

def pw3(x, udata):
    # the None slice is to create new axis
    step_bool = x >= udata[None,:].T

    # we exploit the fact that bools are integer value of 1s
    # skipping the last value in "data"
    step_vals = np.sum(step_bool[:-1], axis=0)

    # for the step_bool that we skipped from previous step (last index)
    # we set it to zerp so that we can negate the step_vals once we reached
    # the last value in "data"
    step_vals[step_bool[-1]] = 0

    return step_vals

После просмотра searchsorted из вашего pw2 У меня был новый подход, который использует его с гораздо более высокой производительностью:

def pw4(x, udata):
    inds = np.searchsorted(udata, x, side='right')

    # fix-ups the last data if x is already out of range of data[-1]
    if x[-1] > udata[-1]:
        inds[inds == inds[-1]] = 0

    return inds

Графики с:

plt.plot(pw1(x,udata.reshape(udata.shape[0],1)), label='pw1')
plt.plot(pw2(x,udata), label='pw2')
plt.plot(pw3(x,udata), label='pw3')
plt.plot(pw4(x,udata), label='pw4')

с data = [1,3,4,5,5,7]:

enter image description here

с data = [1,3,4,5,5,7,11]

enter image description here

pw1, pw3, pw4 все идентичны

print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw3(x,udata)))
>>> True
print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw4(x,udata)))
>>> True

Производительность: (timeit по умолчанию выполняется 3 раза, в среднем number=N раз)

print(timeit.Timer('pw1(x,udata.reshape(udata.shape[0],1))', "from __main__ import pw1, x, udata").repeat(number=1000))
>>> [3.1938983199979702, 1.6096494779994828, 1.962694135003403]
print(timeit.Timer('pw2(x,udata)', "from __main__ import pw2, x, udata").repeat(number=1000))
>>> [0.6884554479984217, 0.6075002400029916, 0.7799002879983163]
print(timeit.Timer('pw3(x,udata)', "from __main__ import pw3, x, udata").repeat(number=1000))
>>> [0.7369808239964186, 0.7557657590004965, 0.8088172269999632]
print(timeit.Timer('pw4(x,udata)', "from __main__ import pw4, x, udata").repeat(number=1000))
>>> [0.20514375300263055, 0.20203858999957447, 0.19906871100101853]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...