Я пытаюсь оптимизировать функцию 'pw' в следующем коде, используя только NumPy функции (или, возможно, списки).
from time import time
import numpy as np
def pw(x, udata):
"""
Creates the step function
| 1, if d0 <= x < d1
| 2, if d1 <= x < d2
pw(x,data) = ...
| N, if d(N-1) <= x < dN
| 0, otherwise
where di is the ith element in data.
INPUT: x -- interval which the step function is defined over
data -- an ordered set of data (without repetitions)
OUTPUT: pw_func -- an array of size x.shape[0]
"""
vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
return pw_func
N = 50000
x = np.linspace(0,10,N)
data = [1,3,4,5,5,7]
udata = np.unique(data)
ti = time()
pw(x,udata)
tf = time()
print(tf - ti)
import cProfile
cProfile.run('pw(x,udata)')
Файл cProfile.run сообщает мне, что большинство накладные расходы приходят из np.where (около 1 мс), но я хотел бы создать более быстрый код, если это возможно. Кажется, что выполнение операций по строкам по сравнению со столбцами имеет некоторое значение, если я не ошибаюсь, но я думаю, что я это учел. Я знаю, что иногда списки могут быть быстрее, но я не мог найти более быстрый способ, чем то, что я делаю, используя их.
Похоже, что поиск с сортировкой дает лучшую производительность, но на моем компьютере все еще остается 1 мс:
(modified)
def pw(xx, uu):
"""
Creates the step function
| 1, if d0 <= x < d1
| 2, if d1 <= x < d2
pw(x,data) = ...
| N, if d(N-1) <= x < dN
| 0, otherwise
where di is the ith element in data.
INPUT: x -- interval which the step function is defined over
data -- an ordered set of data (without repetitions)
OUTPUT: pw_func -- an array of size x.shape[0]
"""
inds = np.searchsorted(uu, xx, side='right')
vals = np.arange(1,uu.shape[0]+1)
pw_func = vals[inds[inds != uu.shape[0]]]
num_mins = np.sum(xx < np.min(uu))
num_maxs = np.sum(xx > np.max(uu))
pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
return pw_func
Этот ответ , использующий кусочно, кажется довольно близким, но это на скаляре x0 и x1. Как бы я сделал это на массивах? И будет ли это более эффективным?
Понятно, что x может быть довольно большим, но я пытаюсь пройти его через стресс-тест.
Я все еще учусь, хотя некоторые подсказки или уловки, которые может помочь мне было бы здорово.
РЕДАКТИРОВАТЬ
Кажется, во второй функции есть ошибка, так как результирующий массив из второй функции не соответствует первый (который, я уверен, что он работает):
N1 = pw1(x,udata.reshape(udata.shape[0],1)).shape[0]
N2 = np.sum(pw1(x,udata.reshape(udata.shape[0],1)) == pw2(x,udata))
print(N1 - N2)
дает
15000
точек данных, которые не одинаковы. Так что, похоже, я не знаю, как использовать 'searchsorted'.
EDIT 2
На самом деле я это исправил:
pw_func = vals[inds[inds != uu.shape[0]]]
было изменено на
pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
, поэтому по крайней мере полученные массивы совпадают. Но все еще остается вопрос о том, есть ли более эффективный способ сделать это.
РЕДАКТИРОВАТЬ 3
Спасибо Тин Лай за указание на ошибку. Этот должен работать
pw_func = vals[inds[(inds != uu.shape[0])*(inds != 0)]-1]
Возможно, более читабельный способ представить его будет
non_endpts = (inds != uu.shape[0])*(inds != 0) # only consider the points in between the min/max data values
shift_inds = inds[non_endpts]-1 # searchsorted side='right' includes the left end point and not right end point so a shift is needed
pw_func = vals[shift_inds]
Я думаю, что я заблудился во всех этих скобках! Я думаю, что это важная читаемость.