Stochasti c анализ чувствительности дифференциального уравнения с заданным шумом - PullRequest
0 голосов
/ 07 марта 2020

Я пытаюсь вычислить градиент функционала решения стохастического c дифференциального уравнения (SDE), учитывая конкретную c реализацию шума. Я могу успешно рассчитать эти градиенты, если я оставлю шум неопределенным, как показано в DiffEqFlux.jl: Использование других дифференциальных уравнений . Я также могу успешно получить решение для моей SDE для конкретной реализации c шума, как показано в DifferentialEquations.jl: Пример NoiseWrapper . Однако когда я пытаюсь соединить их вместе, код возвращает ошибку.

Вот минимальный рабочий пример, адаптированный из двух отдельных примеров, упомянутых выше:

using StochasticDiffEq, DiffEqBase, DiffEqNoiseProcess, DiffEqSensitivity, Zygote

function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
function lotka_volterra_noise(du,u,p,t)
  du[1] = 0.1u[1]
  du[2] = 0.1u[2]
end
dt = 1//2^(4)
u0 = [1.0,1.0]
p = [2.2, 1.0, 2.0, 0.4]
prob1 = SDEProblem(lotka_volterra,lotka_volterra_noise,u0,(0.0,10.0),p)
sol1 = solve(prob1,EM(),dt=dt,save_noise=true)

W2 = NoiseWrapper(sol1.W)
prob2 = SDEProblem(lotka_volterra,lotka_volterra_noise,u0,(0.0,10.0),p,noise=W2)
sol2 = solve(prob2,EM(),dt=dt)

function predict_sde1(p)
  Array(concrete_solve(remake(prob1,p=p),EM(),dt=dt,sensealg=ForwardDiffSensitivity(),saveat=0.1))
end
loss_sde1(p) = sum(abs2,x-1 for x in predict_sde1(p))

loss_sde1(p)

# This gradient is successfully calculated
Zygote.gradient(loss_sde1,p)

function predict_sde2(p)
  W2 = NoiseWrapper(sol1.W)
  Array(concrete_solve(remake(prob2,p=p,noise=W2),EM(),dt=dt,sensealg=ForwardDiffSensitivity(),saveat=0.1))
end
loss_sde2(p) = sum(abs2,x-1 for x in predict_sde2(p))

# This loss is successfully calculated
loss_sde2(p)

# This gradient calculation raises and error
Zygote.gradient(loss_sde2,p)

Ошибка I получить в конце выполнения этот код

TypeError: in setfield!, expected Float64, got ForwardDiff.Dual{Nothing,Float64,4}

Stacktrace:
 [1] setproperty! at ./Base.jl:21 [inlined]
...

с последующим бесконечным выводом на трассировку стека (я могу опубликовать его, если вы считаете, что это будет полезно, но, поскольку он длиннее остальной части этого вопроса, я я бы не хотел загромождать вещи с ног на голову).

Расчет градиентов для задач SDE с заданными реализациями шума в настоящее время не поддерживается, или я просто не выполняю соответствующие вызовы функций? Я мог легко поверить в последнее, так как это была небольшая борьба, чтобы добраться до точки, где работали рабочие части приведенного выше кода, но я не мог найти никакой подсказки относительно того, что я неправильно предоставил после прохождения этого код с помощью отладчика Juno.

1 Ответ

2 голосов
/ 10 марта 2020

В качестве решения StackOverflow вы можете использовать ForwardDiffSensitivity(convert_tspan=false), чтобы обойти это. Рабочий код:

using StochasticDiffEq, DiffEqBase, DiffEqNoiseProcess, DiffEqSensitivity, Zygote

function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
function lotka_volterra_noise(du,u,p,t)
  du[1] = 0.1u[1]
  du[2] = 0.1u[2]
end
dt = 1//2^(4)
u0 = [1.0,1.0]
p = [2.2, 1.0, 2.0, 0.4]
prob1 = SDEProblem(lotka_volterra,lotka_volterra_noise,u0,(0.0,10.0),p)
sol1 = solve(prob1,EM(),dt=dt,save_noise=true)

W2 = NoiseWrapper(sol1.W)
prob2 = SDEProblem(lotka_volterra,lotka_volterra_noise,u0,(0.0,10.0),p,noise=W2)
sol2 = solve(prob2,EM(),dt=dt)

function predict_sde1(p)
  Array(concrete_solve(remake(prob1,p=p),EM(),dt=dt,sensealg=ForwardDiffSensitivity(convert_tspan=false),saveat=0.1))
end
loss_sde1(p) = sum(abs2,x-1 for x in predict_sde1(p))

loss_sde1(p)

# This gradient is successfully calculated
Zygote.gradient(loss_sde1,p)

function predict_sde2(p)
  Array(concrete_solve(prob2,EM(),prob2.u0,p,dt=dt,sensealg=ForwardDiffSensitivity(convert_tspan=false),saveat=0.1))
end
loss_sde2(p) = sum(abs2,x-1 for x in predict_sde2(p))

# This loss is successfully calculated
loss_sde2(p)

# This gradient calculation raises and error
Zygote.gradient(loss_sde2,p)

Как разработчик ... это не очень хорошее решение, и наш стандарт должен быть лучше. Я буду работать над этим. Вы можете отслеживать развитие здесь https://github.com/JuliaDiffEq/DiffEqSensitivity.jl/issues/204. Вероятно, это будет решено через час или около того.

Редактировать: Исправление выпущено, и ваш оригинальный код работает.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...