Краткий обзор кода и попытка его китонизации, простое добавление типов ndarray ко всем параметрам и переменным, не приводит к значительному изменению производительности. Если вы боретесь за сокращение микросекунд для этой функции в этом тесном внутреннем цикле, я бы рассмотрел следующие изменения:
- Причиной, по которой этот код так сложно цитонизировать, является то, что ваш код векторизован. Все операции проходят через
numpy
или numexpr
. Несмотря на то, что каждая из этих операций эффективна, все они добавляют некоторые накладные расходы на Python (что можно увидеть, если посмотреть на аннотированные .html
файлы, которые может создавать Cython).
- Если вы вызываете эту функцию много раз (как это выглядит на основе ваших комментариев), вы можете сэкономить некоторое время, сделав вместо этого
mktout
cdef
функцию. Вызовы функций Python имеют значительные накладные расходы.
- Незначительно, но вы можете попытаться избежать любых функций из модуля
math
Python. Вы можете заменить это на from libc cimport math as cmath
и использовать вместо него cmath.exp
.
- Я вижу, что ваша
mktout
функция принимает в список Python mean_mu_alpha
. Вы можете использовать объект cdef class
для замены этого параметра и ввести его вместо этого. Если вместо этого вы решите сделать функцию mktout
a cdef
, она может стать просто структурой или double *
массивом. В любом случае, индексирование в список python (который может содержать произвольные объекты python, которые необходимо распаковать в соответствующие c-типы) будет медленным.
- Возможно, это самая важная часть. Для каждого вызова
mktout
вы выделяете память для большого количества массивов (для каждого mu
, alpha
, threshold
, case
, t-
и p-
массив). Затем вы приступаете к освобождению всей этой памяти в конце функции (через gc Python), только чтобы снова использовать все это пространство при следующем вызове. Если вы можете изменить сигнатуру mktout
, вы можете передать все эти массивы в качестве параметров, чтобы память можно было повторно использовать и перезаписывать при вызовах функций . Другим вариантом, который лучше для этого случая, будет итерация по массиву и выполнение всех вычислений по одному элементу за раз.
- Вы можете использовать многопоточность кода, используя функцию
prange
в Cython. Я достигну этого после того, как вы внесете все вышеперечисленные изменения, и я бы сделал многопоточность вне самой функции mktout
. То есть вы будете использовать многопоточность вызовов mktout вместо многопоточности mktout
.
Внесение вышеуказанных изменений будет большой работой, и вам, вероятно, придется переопределить многие функции, предоставляемые numpy и Numberxpr, чтобы избежать накладных расходов на python, связанных с каждым разом. Пожалуйста, дайте мне знать, если какая-то часть этого неясна.
Обновление № 1: При реализации пунктов № 1, № 3 и № 5 я получаю 11-кратное ускорение . Вот как выглядит этот код. Я уверен, что он может пойти быстрее, если вы отключите функцию def
, вход list mean_mu_alpha
и выход tuple
. Примечание. В последнем десятичном знаке я получаю немного другие результаты по сравнению с исходным кодом, вероятно из-за некоторых правил с плавающей запятой, которые я не понимаю.
from libc cimport math as cmath
from libc.stdint cimport *
from libc.stdlib cimport *
def mktout(list mean_mu_alpha, double[:, ::1] errors, double par_gamma):
cdef:
size_t i, n
double[4] exp
double exp_par_gamma
double mu10, mu11, mu20, mu21
double alpha1, alpha2
bint j_is_larger, j_is_smaller
double threshold2, threshold3
bint case1, case2, case3, case4, case5, case6
double t0, t1, t2
double p12, p1, p2
double t1_sum, t2_sum, p1_sum, p2_sum
double c
#compute the exp outside of the loop
n = errors.shape[0]
exp[0] = cmath.exp(<double>mean_mu_alpha[0])
exp[1] = cmath.exp(<double>mean_mu_alpha[1])
exp[2] = cmath.exp(<double>mean_mu_alpha[2])
exp[3] = cmath.exp(<double>mean_mu_alpha[3])
exp_par_gamma = cmath.exp(par_gamma)
c = 168.0
t1_sum = 0.0
t2_sum = 0.0
p1_sum = 0.0
p2_sum = 0.0
for i in range(n):
mu10 = errors[i, 0] * exp[0]
mu11 = exp_par_gamma * mu10
mu20 = errors[i, 1] * exp[1]
mu21 = exp_par_gamma * mu20
alpha1 = errors[i, 2] * exp[2]
alpha2 = errors[i, 3] * exp[3]
j_is_larger = mu10 > mu20
j_is_smaller = not j_is_larger
threshold2 = (1 + mu10 * alpha1) / (c + alpha1)
threshold3 = (1 + mu20 * alpha2) / (c + alpha2)
case1 = j_is_larger * (mu10 < 1 / c)
case2 = j_is_larger * (mu21 >= threshold2)
case3 = j_is_larger ^ (case1 | case2)
case4 = j_is_smaller * (mu20 < 1 / c)
case5 = j_is_smaller * (mu11 >= threshold3)
case6 = j_is_smaller ^ (case4 | case5)
t0 = case1*c+case2 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) +case3 / threshold2 +case4 * c +case5 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) + case6 / threshold3
t1 = case2 * (t0 * alpha1 * mu11 - alpha1) +case3 * (t0 * alpha1 * mu10 - alpha1) +case5 * (t0 * alpha1 * mu11 - alpha1)
t2 = c - t0 - t1
p12 = case2 + case5
p1 = case3 + p12
p2 = case6 + p12
t1_sum += t1
t2_sum += t2
p1_sum += p1
p2_sum += p2
return t1_sum/n, t2_sum/n, p1_sum/n, p2_sum/n
Обновление № 2: Реализованы идеи cdef
(# 2), исключения объектов Python (# 4) и многопоточности (# 6). Только # 2 и # 4 имели незначительную выгоду, но были необходимы для # 6, так как GIL не может быть доступен в циклах OpenMP prange
. Благодаря многопоточности вы получаете дополнительное увеличение скорости в 2,5 раза на моем четырехъядерном ноутбуке, что составляет код, который примерно в 27,5 раза быстрее оригинального. Моя функция outer_loop
не совсем точна, хотя она просто пересчитывает один и тот же результат снова и снова, но этого должно быть достаточно для тестового примера. Полный код ниже:
from libc cimport math as cmath
from libc.stdint cimport *
from libc.stdlib cimport *
from cython.parallel cimport prange
def mktout(list mean_mu_alpha, double[:, ::1] errors, double par_gamma):
cdef:
size_t i, n
double[4] exp
double exp_par_gamma
double mu10, mu11, mu20, mu21
double alpha1, alpha2
bint j_is_larger, j_is_smaller
double threshold2, threshold3
bint case1, case2, case3, case4, case5, case6
double t0, t1, t2
double p12, p1, p2
double t1_sum, t2_sum, p1_sum, p2_sum
double c
#compute the exp outside of the loop
n = errors.shape[0]
exp[0] = cmath.exp(<double>mean_mu_alpha[0])
exp[1] = cmath.exp(<double>mean_mu_alpha[1])
exp[2] = cmath.exp(<double>mean_mu_alpha[2])
exp[3] = cmath.exp(<double>mean_mu_alpha[3])
exp_par_gamma = cmath.exp(par_gamma)
c = 168.0
t1_sum = 0.0
t2_sum = 0.0
p1_sum = 0.0
p2_sum = 0.0
for i in range(n):
mu10 = errors[i, 0] * exp[0]
mu11 = exp_par_gamma * mu10
mu20 = errors[i, 1] * exp[1]
mu21 = exp_par_gamma * mu20
alpha1 = errors[i, 2] * exp[2]
alpha2 = errors[i, 3] * exp[3]
j_is_larger = mu10 > mu20
j_is_smaller = not j_is_larger
threshold2 = (1 + mu10 * alpha1) / (c + alpha1)
threshold3 = (1 + mu20 * alpha2) / (c + alpha2)
case1 = j_is_larger * (mu10 < 1 / c)
case2 = j_is_larger * (mu21 >= threshold2)
case3 = j_is_larger ^ (case1 | case2)
case4 = j_is_smaller * (mu20 < 1 / c)
case5 = j_is_smaller * (mu11 >= threshold3)
case6 = j_is_smaller ^ (case4 | case5)
t0 = case1*c+case2 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) +case3 / threshold2 +case4 * c +case5 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) + case6 / threshold3
t1 = case2 * (t0 * alpha1 * mu11 - alpha1) +case3 * (t0 * alpha1 * mu10 - alpha1) +case5 * (t0 * alpha1 * mu11 - alpha1)
t2 = c - t0 - t1
p12 = case2 + case5
p1 = case3 + p12
p2 = case6 + p12
t1_sum += t1
t2_sum += t2
p1_sum += p1
p2_sum += p2
return t1_sum/n, t2_sum/n, p1_sum/n, p2_sum/n
ctypedef struct Vec4:
double a
double b
double c
double d
def outer_loop(list mean_mu_alpha, double[:, ::1] errors, double par_gamma, size_t n):
cdef:
size_t i
Vec4 mean_vec
Vec4 out
mean_vec.a = <double>(mean_mu_alpha[0])
mean_vec.b = <double>(mean_mu_alpha[1])
mean_vec.c = <double>(mean_mu_alpha[2])
mean_vec.d = <double>(mean_mu_alpha[3])
with nogil:
for i in prange(n):
cy_mktout(&out, &mean_vec, errors, par_gamma)
return out
cdef void cy_mktout(Vec4 *out, Vec4 *mean_mu_alpha, double[:, ::1] errors, double par_gamma) nogil:
cdef:
size_t i, n
double[4] exp
double exp_par_gamma
double mu10, mu11, mu20, mu21
double alpha1, alpha2
bint j_is_larger, j_is_smaller
double threshold2, threshold3
bint case1, case2, case3, case4, case5, case6
double t0, t1, t2
double p12, p1, p2
double t1_sum, t2_sum, p1_sum, p2_sum
double c
#compute the exp outside of the loop
n = errors.shape[0]
exp[0] = cmath.exp(mean_mu_alpha.a)
exp[1] = cmath.exp(mean_mu_alpha.b)
exp[2] = cmath.exp(mean_mu_alpha.c)
exp[3] = cmath.exp(mean_mu_alpha.d)
exp_par_gamma = cmath.exp(par_gamma)
c = 168.0
t1_sum = 0.0
t2_sum = 0.0
p1_sum = 0.0
p2_sum = 0.0
for i in range(n):
mu10 = errors[i, 0] * exp[0]
mu11 = exp_par_gamma * mu10
mu20 = errors[i, 1] * exp[1]
mu21 = exp_par_gamma * mu20
alpha1 = errors[i, 2] * exp[2]
alpha2 = errors[i, 3] * exp[3]
j_is_larger = mu10 > mu20
j_is_smaller = not j_is_larger
threshold2 = (1 + mu10 * alpha1) / (c + alpha1)
threshold3 = (1 + mu20 * alpha2) / (c + alpha2)
case1 = j_is_larger * (mu10 < 1 / c)
case2 = j_is_larger * (mu21 >= threshold2)
case3 = j_is_larger ^ (case1 | case2)
case4 = j_is_smaller * (mu20 < 1 / c)
case5 = j_is_smaller * (mu11 >= threshold3)
case6 = j_is_smaller ^ (case4 | case5)
t0 = case1*c+case2 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) +case3 / threshold2 +case4 * c +case5 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) + case6 / threshold3
t1 = case2 * (t0 * alpha1 * mu11 - alpha1) +case3 * (t0 * alpha1 * mu10 - alpha1) +case5 * (t0 * alpha1 * mu11 - alpha1)
t2 = c - t0 - t1
p12 = case2 + case5
p1 = case3 + p12
p2 = case6 + p12
t1_sum += t1
t2_sum += t2
p1_sum += p1
p2_sum += p2
out.a = t1_sum/n
out.b = t2_sum/n
out.c = p1_sum/n
out.d = p2_sum/n
И файл setup.py
, который я использую, выглядит следующим образом (содержит все флаги оптимизации и OpenMP):
from distutils.core import setup
from Cython.Build import cythonize
from distutils.core import Extension
import numpy as np
import os
import shutil
import platform
libraries = {
"Linux": [],
"Windows": [],
}
language = "c"
args = ["-w", "-std=c11", "-O3", "-ffast-math", "-march=native", "-fopenmp"]
link_args = ["-std=c11", "-fopenmp"]
annotate = True
directives = {
"binding": True,
"boundscheck": False,
"wraparound": False,
"initializedcheck": False,
"cdivision": True,
"nonecheck": False,
"language_level": "3",
#"c_string_type": "unicode",
#"c_string_encoding": "utf-8",
}
if __name__ == "__main__":
system = platform.system()
libs = libraries[system]
extensions = []
ext_modules = []
#create extensions
for path, dirs, file_names in os.walk("."):
for file_name in file_names:
if file_name.endswith("pyx"):
ext_path = "{0}/{1}".format(path, file_name)
ext_name = ext_path \
.replace("./", "") \
.replace("/", ".") \
.replace(".pyx", "")
ext = Extension(
name=ext_name,
sources=[ext_path],
libraries=libs,
language=language,
extra_compile_args=args,
extra_link_args=link_args,
include_dirs = [np.get_include()],
)
extensions.append(ext)
#setup all extensions
ext_modules = cythonize(
extensions,
annotate=annotate,
compiler_directives=directives,
)
setup(ext_modules=ext_modules)
"""
#immediately remove build directory
build_dir = "./build"
if os.path.exists(build_dir):
shutil.rmtree(build_dir)
"""
Обновление № 3: По совету @ GZ0 было много условий, в которых выражения в коде будут обнуляться и будут расточительно вычисляться.Я попытался удалить эти области с помощью следующего кода (после исправления операторов case3
и case6
):
cdef void cy_mktout_if(Vec4 *out, Vec4 *mean_mu_alpha, double[:, ::1] errors, double par_gamma) nogil:
cdef:
size_t i, n
double[4] exp
double exp_par_gamma
double mu10, mu11, mu20, mu21
double alpha1, alpha2
bint j_is_larger
double threshold2, threshold3
bint case1, case2, case3, case4, case5, case6
double t0, t1, t2
double p12, p1, p2
double t1_sum, t2_sum, p1_sum, p2_sum
double c
#compute the exp outside of the loop
n = errors.shape[0]
exp[0] = cmath.exp(mean_mu_alpha.a)
exp[1] = cmath.exp(mean_mu_alpha.b)
exp[2] = cmath.exp(mean_mu_alpha.c)
exp[3] = cmath.exp(mean_mu_alpha.d)
exp_par_gamma = cmath.exp(par_gamma)
c = 168.0
t1_sum = 0.0
t2_sum = 0.0
p1_sum = 0.0
p2_sum = 0.0
for i in range(n):
mu10 = errors[i, 0] * exp[0]
mu11 = exp_par_gamma * mu10
mu20 = errors[i, 1] * exp[1]
mu21 = exp_par_gamma * mu20
alpha1 = errors[i, 2] * exp[2]
alpha2 = errors[i, 3] * exp[3]
j_is_larger = mu10 > mu20
j_is_smaller = not j_is_larger
threshold2 = (1 + mu10 * alpha1) / (c + alpha1)
threshold3 = (1 + mu20 * alpha2) / (c + alpha2)
if j_is_larger:
case1 = mu10 < 1 / c
case2 = mu21 >= threshold2
case3 = not (case1 | case2)
t0 = case1*c + case2 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) + case3 / threshold2
t1 = case2 * (t0 * alpha1 * mu11 - alpha1) + case3 * (t0 * alpha1 * mu10 - alpha1)
t2 = c - t0 - t1
t1_sum += t1
t2_sum += t2
p1_sum += case2 + case3
p2_sum += case2
else:
case4 = mu20 < 1 / c
case5 = mu11 >= threshold3
case6 = not (case4 | case5)
t0 = case4 * c + case5 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) + case6 / threshold3
t1 = case5 * (t0 * alpha1 * mu11 - alpha1)
t2 = c - t0 - t1
t1_sum += t1
t2_sum += t2
p1_sum += case5
p2_sum += case5 + case6
out.a = t1_sum/n
out.b = t2_sum/n
out.c = p1_sum/n
out.d = p2_sum/n
Для 10000 итераций текущий код работает следующим образом:
outer_loop: 0.5116949229995953 seconds
outer_loop_if: 0.617649456995423 seconds
mktout: 0.9221872320049442 seconds
mktout_if: 1.430276553001022 seconds
python: 10.116664300003322 seconds
Я думаю, что стоимость условного и отраслевого неверного прогноза, что в результате делает функцию удивительно медленной, но я был бы признателен за любую помощь, чтобы наверняка это прояснить.