Я пытаюсь вычислить градиент функционала решения стохастического c дифференциального уравнения (SDE), учитывая конкретную c реализацию шума. Я могу успешно рассчитать эти градиенты, если я оставлю шум неопределенным, как показано в DiffEqFlux.jl: Использование других дифференциальных уравнений . Я также могу успешно получить решение для моей SDE для конкретной реализации c шума, как показано в DifferentialEquations.jl: Пример NoiseWrapper . Однако когда я пытаюсь соединить их вместе, код возвращает ошибку.
Вот минимальный рабочий пример, адаптированный из двух отдельных примеров, упомянутых выше:
using StochasticDiffEq, DiffEqBase, DiffEqNoiseProcess, DiffEqSensitivity, Zygote
function lotka_volterra(du,u,p,t)
x, y = u
α, β, δ, γ = p
du[1] = dx = α*x - β*x*y
du[2] = dy = -δ*y + γ*x*y
end
function lotka_volterra_noise(du,u,p,t)
du[1] = 0.1u[1]
du[2] = 0.1u[2]
end
dt = 1//2^(4)
u0 = [1.0,1.0]
p = [2.2, 1.0, 2.0, 0.4]
prob1 = SDEProblem(lotka_volterra,lotka_volterra_noise,u0,(0.0,10.0),p)
sol1 = solve(prob1,EM(),dt=dt,save_noise=true)
W2 = NoiseWrapper(sol1.W)
prob2 = SDEProblem(lotka_volterra,lotka_volterra_noise,u0,(0.0,10.0),p,noise=W2)
sol2 = solve(prob2,EM(),dt=dt)
function predict_sde1(p)
Array(concrete_solve(remake(prob1,p=p),EM(),dt=dt,sensealg=ForwardDiffSensitivity(),saveat=0.1))
end
loss_sde1(p) = sum(abs2,x-1 for x in predict_sde1(p))
loss_sde1(p)
# This gradient is successfully calculated
Zygote.gradient(loss_sde1,p)
function predict_sde2(p)
W2 = NoiseWrapper(sol1.W)
Array(concrete_solve(remake(prob2,p=p,noise=W2),EM(),dt=dt,sensealg=ForwardDiffSensitivity(),saveat=0.1))
end
loss_sde2(p) = sum(abs2,x-1 for x in predict_sde2(p))
# This loss is successfully calculated
loss_sde2(p)
# This gradient calculation raises and error
Zygote.gradient(loss_sde2,p)
Ошибка I получить в конце выполнения этот код
TypeError: in setfield!, expected Float64, got ForwardDiff.Dual{Nothing,Float64,4}
Stacktrace:
[1] setproperty! at ./Base.jl:21 [inlined]
...
с последующим бесконечным выводом на трассировку стека (я могу опубликовать его, если вы считаете, что это будет полезно, но, поскольку он длиннее остальной части этого вопроса, я я бы не хотел загромождать вещи с ног на голову).
Расчет градиентов для задач SDE с заданными реализациями шума в настоящее время не поддерживается, или я просто не выполняю соответствующие вызовы функций? Я мог легко поверить в последнее, так как это была небольшая борьба, чтобы добраться до точки, где работали рабочие части приведенного выше кода, но я не мог найти никакой подсказки относительно того, что я неправильно предоставил после прохождения этого код с помощью отладчика Juno.