Одним из способов ускорить вычисление является использование numba
, компилятора для своевременного выполнения Python.
Декоратор @jit
Numba предоставляет @jit
декоратор для компиляции некоторого Python кода и вывода оптимизированного машинного кода, который может выполняться параллельно на нескольких ЦПУ. Соединение с функцией интегрирования занимает мало усилий и позволит сэкономить немного времени, поскольку код оптимизирован для ускорения работы. С типами даже не нужно беспокоиться, Numba делает все это под капотом.
from scipy import integrate
from numba import jit
@jit
def circular_jit(x, y, a):
if x**2 + y**2 < a**2 / 4:
return 1
else:
return 0
a = 4
result = integrate.nquad(circular_jit, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
Это действительно работает быстрее, и при синхронизации на моей машине я получаю:
Original circular function: 1.599048376083374
Jitted circular function: 0.8280022144317627
Это на 50% сокращает время вычислений.
Scipy's LowLevelCallable
Вызовы функций в Python довольно трудоемки из-за особенностей языка. Затраты иногда могут сделать код Python медленным по сравнению со скомпилированными языками, такими как C.
Чтобы смягчить это, Scipy предоставляет класс LowLevelCallable
, который можно использовать для обеспечить доступ к скомпилированной функции обратного вызова низкого уровня. Благодаря этому механизму издержки вызова функции Python обходятся, и можно добиться дополнительной экономии времени.
Обратите внимание, что в случае nquad
подпись cfunc
передается LowerLevelCallable
. должен быть одним из:
double func(int n, double *xx)
double func(int n, double *xx, void *user_data)
, где int
- количество аргументов, а значения аргументов находятся во втором аргументе. user_data
используется для обратных вызовов, для работы которых требуется контекст.
Поэтому мы можем немного изменить сигнатуру циклической функции в Python, чтобы сделать ее совместимой.
from scipy import integrate, LowLevelCallable
from numba import cfunc
from numba.types import intc, CPointer, float64
@cfunc(float64(intc, CPointer(float64)))
def circular_cfunc(n, args):
x, y, a = (args[0], args[1], args[2]) # Cannot do `(args[i] for i in range(n))` as `yield` is not supported
if x**2 + y**2 < a**2/4:
return 1
else:
return 0
circular_LLC = LowLevelCallable(circular_cfunc.ctypes)
a = 4
result = integrate.nquad(circular_LLC, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
С помощью этого метода I get
LowLevelCallable circular function: 0.07962369918823242
Это уменьшение на 95% по сравнению с оригиналом и на 90% по сравнению с сопряженной версией функции.
Декоратор на заказ
В порядке чтобы сделать код более аккуратным и сохранить гибкость сигнатуры подынтегральной функции, можно создать специальную функцию декоратора. Он объединит функцию подынтегрального оператора и обернет его в LowLevelCallable
объект, который затем можно будет использовать с nquad
.
from scipy import integrate, LowLevelCallable
from numba import cfunc, jit
from numba.types import intc, CPointer, float64
def jit_integrand_function(integrand_function):
jitted_function = jit(integrand_function, nopython=True)
@cfunc(float64(intc, CPointer(float64)))
def wrapped(n, xx):
return jitted_function(xx[0], xx[1], xx[2])
return LowLevelCallable(wrapped.ctypes)
@jit_integrand_function
def circular(x, y, a):
if x**2 + y**2 < a**2 / 4:
return 1
else:
return 0
a = 4
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
Произвольное количество аргументов
Если число аргументов равно неизвестно, тогда мы можем использовать удобную carray
функцию , предоставленную Numba для преобразования CPointer(float64)
в Numpy массив.
import numpy as np
from scipy import integrate, LowLevelCallable
from numba import cfunc, carray, jit
from numba.types import intc, CPointer, float64
def jit_integrand_function(integrand_function):
jitted_function = jit(integrand_function, nopython=True)
@cfunc(float64(intc, CPointer(float64)))
def wrapped(n, xx):
ar = carray(xx, n)
return jitted_function(ar[0], ar[1], ar[2:])
return LowLevelCallable(wrapped.ctypes)
@jit_integrand_function
def circular(x, y, a):
if x**2 + y**2 < a[-1]**2 / 4:
return 1
else:
return 0
ar = np.array([1, 2, 3, 4])
a = ar[-1]
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=ar)