Я пытался воспроизвести модель из PYMC3 и Stan сравнения .Но, похоже, он работает медленно, и когда я смотрю на @code_warntype
, я думаю, что есть некоторые вещи - K
и N
- которые компилятор называет Any
.
.добавление типов - хотя я не могу добавить типы к аргументам turing_model
, и в turing_model
все сложно, потому что он использует переменные autodiff, а не обычные.Я поместил весь код в функцию do_it
, чтобы избежать глобальных переменных, потому что они говорят, что глобальные переменные могут замедлять работу.(Хотя на самом деле это кажется медленнее.)
Есть какие-нибудь предложения относительно причины проблемы?Код turing_model
- это то, что повторяется, так что это должно иметь наибольшее значение.
using Turing, StatsPlots, Random
sigmoid(x) = 1.0 / (1.0 + exp(-x))
function scale(w0::Float64, w1::Array{Float64,1})
scale = √(w0^2 + sum(w1 .^ 2))
return w0 / scale, w1 ./ scale
end
function do_it(iterations::Int64)::Chains
K = 10 # predictor dimension
N = 1000 # number of data samples
X = rand(N, K) # predictors (1000, 10)
w1 = rand(K) # weights (10,)
w0 = -median(X * w1) # 50% of elements for each class (number)
w0, w1 = scale(w0, w1) # unit length (euclidean)
w_true = [w0, w1...]
y = (w0 .+ (X * w1)) .> 0.0 # labels
y = [Float64(x) for x in y]
σ = 5.0
σm = [x == y ? σ : 0.0 for x in 1:K, y in 1:K]
@model turing_model(X, y, σ, σm) = begin
w0_pred ~ Normal(0.0, σ)
w1_pred ~ MvNormal(σm)
p = sigmoid.(w0_pred .+ (X * w1_pred))
@inbounds for n in 1:length(y)
y[n] ~ Bernoulli(p[n])
end
end
@time chain = sample(turing_model(X, y, σ, σm), NUTS(iterations, 200, 0.65));
# ϵ = 0.5
# τ = 10
# @time chain = sample(turing_model(X, y, σ), HMC(iterations, ϵ, τ));
return (w_true=w_true, chains=chain::Chains)
end
chain = do_it(1000)