Оптимизация рекурсивной функции в Юлии - PullRequest
0 голосов
/ 03 июля 2018

Я написал код Джулии, который вычисляет интегралы по гауссовским функциям, и у меня есть своего рода функция ядра, которая вызывается снова и снова. Согласно встроенному Profile модулю Julia, именно здесь я провожу большую часть времени во время фактических вычислений, и поэтому я хотел бы посмотреть, есть ли способ улучшить его.

Это рекурсивная функция, и я реализовал ее простым способом. Поскольку я не настолько привык к рекурсивным функциям, возможно, у кого-то есть какие-то идеи / предложения о том, как его улучшить (как с чисто теоретической алгоритмической точки зрения, так и / или с использованием специальных оптимизаций из JIT-компилятора).

Вот оно:

"""Returns the integral of an Hermite Gaussian divided by the Coulomb operator."""
function Rtuv{T<:Real}(t::Int, u::Int, v::Int, n::Int, p::Real, RPC::Vector{T})
    if t == u == v == 0
        return (-2.0*p)^n * boys(n,p*norm(RPC)^2)
    elseif u == v == 0
        if t > 1
            return  (t-1)*Rtuv(t-2, u, v, n+1, p, RPC) +
                   RPC[1]*Rtuv(t-1, u, v, n+1, p, RPC)
        else
            return RPC[1]*Rtuv(t-1, u, v, n+1, p, RPC)
        end
    elseif v == 0
        if u > 1
            return  (u-1)*Rtuv(t, u-2, v, n+1, p, RPC) +
                   RPC[2]*Rtuv(t, u-1, v, n+1, p, RPC)
        else
            return RPC[2]*Rtuv(t, u-1, v, n+1, p ,RPC)
        end
    else
        if v > 1
            return  (v-1)*Rtuv(t, u, v-2, n+1, p, RPC)
                   RPC[3]*Rtuv(t, u, v-1, n+1, p, RPC)
        else
            return RPC[3]*Rtuv(t, u, v-1, n+1, p, RPC)
        end
    end
end

Не обращайте так много внимания на функцию boys, поскольку, согласно профилировщику, она не такая уж и тяжелая.
Просто чтобы дать представление о диапазоне номеров: обычно первый звонок поступает с t+u+v в диапазоне от 0 до 3, тогда как n всегда начинается с 0.

Приветствия

РЕДАКТИРОВАТЬ - Новая информация

Сгенерированная версия медленнее для небольших значений t,u,v, я думаю, причина в том, что выражения не оптимизируются компилятором. Я плохо тестировал для этого случая, не интерполируя переданный аргумент. Делая это правильно, я всегда быстрее с подходом, описанным в принятом ответе, так что ура!

В целом, компилятор определяет тривиальные случаи, такие как умножение на нули и единицы, и оптимизирует их?

Ответьте себе: из быстрой проверки простого кода с помощью @code_llvm похоже, что это не так.

1 Ответ

0 голосов
/ 03 июля 2018

Возможно, это работает в вашем случае: вы можете «запоминать» целые скомпилированные методы, используя сгенерированные функции, и избавляться от всех рекурсий после первого вызова.

Поскольку t, u и v будут оставаться маленькими, вы можете сгенерировать полностью расширенный код для рекурсий. Предположим для простоты фиктивная реализация

boys(n::Int, x::Real) = n + x

Тогда

function Rtuv_expr(t::Int, u::Int, v::Int, n, p, RPC)
    ninc = :($n + 1)

    if t == u == v == 0
        :((-2.0 * $p)^$n * boys($n, $p * norm($RPC)^2))
    elseif u == v == 0
        if t > 1
            :($(t-1) * $(Rtuv_expr(t-2, u, v, ninc, p, RPC)) +
              $RPC[1] * $(Rtuv_expr(t-1, u, v, ninc, p, RPC)))
        else
            :($RPC[1] * $(Rtuv_expr(t-1, u, v, ninc, p, RPC)))
        end
    elseif v == 0
        if u > 1
            :($(u-1) * $(Rtuv_expr(t, u-2, v, ninc, p, RPC)) +
              $RPC[2] * $(Rtuv_expr(t, u-1, v, ninc, p, RPC)))
        else
            :($RPC[2] * $(Rtuv_expr(t, u-1, v, ninc, p, RPC)))
        end
    else
        if v > 1 
            :($(v-1) * $(Rtuv_expr(t, u, v-2, ninc, p, RPC)) + 
              $RPC[3] * $(Rtuv_expr(t, u, v-1, ninc, p, RPC)))
        else
            :($RPC[3] * $(Rtuv_expr(t, u, v-1, ninc, p, RPC)))
        end
    end
end

сгенерирует вам полностью расширенные выражения, подобные этому:

julia> Rtuv_expr(1, 2, 1, 0, 0.1, rand(3))
:(([0.868194, 0.928591, 0.295344])[3] * (1 * (([0.868194, 0.928591, 0.295344])[1] * ((-2.0 * 0.1) ^ (((0 + 1) + 1) + 1) * boys(((0 + 1) + 1) + 1, 0.1 * norm([0.868194, 0.928591, 0.295344]) ^ 2))) + ([0.868194, 0.928591, 0.295344])[2] * (([0.868194, 0.928591, 0.295344])[2] * (([0.868194, 0.928591, 0.295344])[1] * ((-2.0 * 0.1) ^ ((((0 + 1) + 1) + 1) + 1) * boys((((0 + 1) + 1) + 1) + 1, 0.1 * norm([0.868194, 0.928591, 0.295344]) ^ 2))))))

Мы можем вставить это в сгенерированную функцию Rtuv, принимающую Val типы. Для каждой отдельной комбинации T, U и V эта функция будет использовать Rtuv_expr для компиляции соответствующего выражения и с тех пор использовать этот метод - больше нет рекурсии:

@generated function Rtuv{T, U, V, X<:Real}(::Type{Val{T}}, ::Type{Val{U}}, ::Type{Val{V}},
                                           n::Int, p::Real, RPC::Vector{X})
    Rtuv_expr(T, U, V, :n, :p, :RPC)
end

Вы должны вызвать его с t, u, v, завернутым в Val, хотя:

julia> Rtuv(Val{1}, Val{2}, Val{1}, 0, 0.1, rand(3))
-0.0007782250832001092

Если вы тестируете маленькую петлю, как это,

for t = 0:3, u = 0:3, v = 0:3
    println(Rtuv(Val{t}, Val{u}, Val{v}, 0, 0.1, [1.0, 2.0, 3.0]))
end

для первого запуска потребуется некоторое время, но потом все пойдет довольно быстро, поскольку используемые методы уже скомпилированы.

...