Как оптимизировать медленный двойной интеграл в python, который использует много интерполяции?Возможно, используя Numba? - PullRequest
0 голосов
/ 21 декабря 2018

Мне крайне необходимо ускорить некоторый код, который включает двойной интеграл с 1D-интерполяцией в обоих шагах.Я сократил его до следующего минимального рабочего примера:

import numpy as np

from scipy.interpolate import interp1d as interp
from scipy.integrate import romberg


def calculate_p22(k):
    data='input_data.dat'
    data=np.loadtxt(data, unpack=True)
    p_int = interp(data[0], data[1])
    lnqmax=np.log(0.2*k)
    lnqmin=-9.0
    result=romberg(dq, lnqmin, lnqmax, args=(k, p_int))
    return result


def dq(lnq, k, p_int):
    q=np.exp(lnq)
    integral_theta=romberg(dtheta, 0.0, np.pi, args=(q, k, p_int))
    return q**3.*p_int(q)*integral_theta


def dtheta(theta, q, k, p_int):
    p=np.sqrt(k**2.+q**2.-2.*k*q*np.cos(theta))
    return p_int(p)

Вызов romberg действительно повышает точность предупреждений, но результаты верны до уровня, которым я доволен.Даже с увеличением допуска код все еще остается медленным.Я пытался понять, как использовать Numba, но столкнулся с трудностью, которую Numba не может использовать interp1d.Я искал альтернативные функции интерполяции Numba в Интернете, но столкнулся с трудностью из-за того, что при вызове функции они выдают желаемые значения x для вывода.Scipy.interpolate.interp1d лучше подходит для адаптивных функций интеграции, потому что я могу создать функцию, а затем вычислять значения в функции интегрирования на лету.

Я был бы очень признателен за помощь в ускорении этого процесса (были бы полезны советы как по Нумбе, так и за пределами Нумбы!)Я работал над расширенной версией кода в течение нескольких недель и собираюсь переписать все на C или Fortran.Но я бы предпочел, чтобы это работало!

Спасибо.


Обновление: Вот вывод из cProfile:

         19720831 function calls (19700522 primitive calls) in 20.513 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   20.513   20.513 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 _iotools.py:31(_is_string_like)
     7666    0.003    0.000    0.028    0.000 _methods.py:31(_sum)
  1054196    0.364    0.000    2.274    0.000 _methods.py:37(_any)
   527098    2.182    0.000    4.804    0.000 _util.py:192(_asarray_validated)
   527098    0.212    0.000    0.484    0.000 base.py:1111(isspmatrix)
   527098    0.195    0.000    0.358    0.000 core.py:6192(isMaskedArray)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:1143(squeeze)
     7666    0.025    0.000    0.056    0.000 fromnumeric.py:1730(sum)
        2    0.000    0.000    0.000    0.000 fromnumeric.py:55(_wrapfunc)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:70(take)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:826(argsort)
   527098    1.120    0.000    2.605    0.000 function_base.py:1934(interp)
        1    0.000    0.000    0.000    0.000 interpolate.py:298(_check_broadcast_up_to)
        2    0.000    0.000    0.000    0.000 interpolate.py:315(_do_extrapolate)
        1    0.000    0.000    0.000    0.000 interpolate.py:403(__init__)
        1    0.000    0.000    0.000    0.000 interpolate.py:511(fill_value)
   527098    0.348    0.000    2.953    0.000 interpolate.py:545(_call_linear_np)
   527098    1.182    0.000    9.746    0.000 interpolate.py:603(_evaluate)
   527098    2.489    0.000    5.250    0.000 interpolate.py:618(_check_bounds)
     1025    0.010    0.000   20.492    0.020 minimum_example.py:17(dq)
   526073    2.759    0.000   19.557    0.000 minimum_example.py:23(dtheta)
        1    0.000    0.000   20.513   20.513 minimum_example.py:7(calculate_p22)
        1    0.000    0.000    0.000    0.000 npyio.py:718(_getconv)
     1792    0.002    0.000    0.002    0.000 npyio.py:721(floatconv)
        1    0.005    0.005    0.018    0.018 npyio.py:748(loadtxt)
        2    0.000    0.000    0.000    0.000 npyio.py:858(<genexpr>)
        1    0.000    0.000    0.000    0.000 npyio.py:906(flatten_dtype_internal)
 1792/896    0.004    0.000    0.004    0.000 npyio.py:935(pack_items)
      897    0.003    0.000    0.004    0.000 npyio.py:951(split_line)
        1    0.000    0.000    0.000    0.000 numeric.py:1432(rollaxis)
     9718    0.017    0.000    0.022    0.000 numeric.py:2135(isscalar)
  1588963    0.792    0.000    1.410    0.000 numeric.py:463(asarray)
   527100    0.219    0.000    0.367    0.000 numerictypes.py:660(issubclass_)
   527100    0.481    0.000    0.946    0.000 numerictypes.py:728(issubdtype)
        1    0.000    0.000    0.000    0.000 polyint.py:105(_reshape_yi)
        1    0.000    0.000    0.000    0.000 polyint.py:113(_set_yi)
        1    0.000    0.000    0.000    0.000 polyint.py:133(_set_dtype)
        1    0.000    0.000    0.000    0.000 polyint.py:55(__init__)
   527098    0.844    0.000   16.836    0.000 polyint.py:62(__call__)
   527098    0.557    0.000    5.545    0.000 polyint.py:88(_prepare_x)
   527098    0.429    0.000    0.701    0.000 polyint.py:94(_finish_y)
        1    0.000    0.000    0.000    0.000 py3k.py:94(is_pathlib_path)
  9718/12    0.589    0.000   20.494    1.708 quadrature.py:117(vfunc)
  8692/11    0.063    0.000   20.495    1.863 quadrature.py:544(_difftrap)
    36282    0.029    0.000    0.029    0.000 quadrature.py:570(_romberg_diff)
   1026/1    0.082    0.000   20.495   20.495 quadrature.py:596(romberg)
     1026    0.001    0.000    0.001    0.000 quadrature.py:88(vectorize1)
        1    0.000    0.000    0.000    0.000 re.py:192(compile)
        1    0.000    0.000    0.000    0.000 re.py:208(escape)
        1    0.000    0.000    0.000    0.000 re.py:230(_compile)
   527098    0.363    0.000    0.483    0.000 type_check.py:251(iscomplexobj)
      221    0.002    0.000    0.002    0.000 {_warnings.warn}
    15332    0.003    0.000    0.003    0.000 {abs}
     7668    0.002    0.000    0.002    0.000 {getattr}
  2125781    0.797    0.000    0.797    0.000 {isinstance}
  1581305    0.367    0.000    0.367    0.000 {issubclass}
        1    0.000    0.000    0.000    0.000 {iter}
   536565    0.059    0.000    0.059    0.000 {len}
  1054196    0.487    0.000    2.761    0.000 {method 'any' of 'numpy.ndarray' objects}
    38074    0.008    0.000    0.008    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 {method 'argsort' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 {method 'close' of 'file' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        2    0.000    0.000    0.000    0.000 {method 'endswith' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'items' of 'dict' objects}
      2/1    0.000    0.000    0.000    0.000 {method 'join' of 'str' objects}
     1792    0.000    0.000    0.000    0.000 {method 'lower' of 'str' objects}
   527099    0.184    0.000    0.184    0.000 {method 'ravel' of 'numpy.ndarray' objects}
  1061862    1.935    0.000    1.935    0.000 {method 'reduce' of 'numpy.ufunc' objects}
   527099    0.272    0.000    0.272    0.000 {method 'reshape' of 'numpy.ndarray' objects}
      897    0.001    0.000    0.001    0.000 {method 'split' of '_sre.SRE_Pattern' objects}
      897    0.000    0.000    0.000    0.000 {method 'split' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'squeeze' of 'numpy.ndarray' objects}
      897    0.000    0.000    0.000    0.000 {method 'strip' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'take' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 {next}
     7666    0.016    0.000    0.016    0.000 {numpy.core.multiarray.arange}
  1588966    0.618    0.000    0.618    0.000 {numpy.core.multiarray.array}
     7666    0.010    0.000    0.010    0.000 {numpy.core.multiarray.empty}
   527098    0.381    0.000    0.381    0.000 {numpy.core.multiarray.interp}
        1    0.000    0.000    0.000    0.000 {numpy.core.multiarray.normalize_axis_index}
        1    0.000    0.000    0.000    0.000 {open}
        2    0.000    0.000    0.000    0.000 {range}
      897    0.001    0.000    0.001    0.000 {zip}

Обновление:Теперь я запустил snakeviz, и вывод можно увидеть здесь:

Вывод Snakeviz

...