Easiest way to programmatically set initial parameters based on the mean of the priors #2394
This is somewhat of a follow-up on #2286. I would like to set my initial parameters to the mean (or the mode) of the prior distributions.

```julia
using Turing, StatsFuns, DataFrames, Random

@model function mod(y, x1, x2)
    μ_intercept ~ Normal(0, 0.5)
    μ_x1 ~ Normal(1, 0.5)
    μ_x2 ~ Normal(2, 0.5)
    σ ~ truncated(Normal(0.0, 1), lower=0)
    μ = μ_intercept .+ μ_x1 .* x1 .+ μ_x2 .* x2
    y .~ Normal.(μ, σ)
end
```

In theory, the optimal way, as far as I can see, would be to have a prior-extractor function that extracts the distributions used in a model. This would be wonderful for many other applications, including visualization of models, but in my case it would allow computing the initial parameters analytically:

```julia
# Ideal solution: extract priors as a collection of distributions
priors = (
    μ_intercept = Normal(0, 0.5),
    μ_x1 = Normal(1, 0.5),
    μ_x2 = Normal(2, 0.5),
    σ = truncated(Normal(0.0, 1), lower=0)
)  # Assuming this was obtained with something like extract_priors()

initial_parameters = mean.(collect(priors))
```
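For reference, the final `mean.(collect(priors))` step would already work analytically with Distributions.jl, including for the truncated normal; a minimal check (the numeric values only illustrate what the means evaluate to):

```julia
using Distributions

# Analytic means of the priors used above
mean(Normal(0, 0.5))                      # 0.0
mean(Normal(1, 0.5))                      # 1.0
mean(Normal(2, 0.5))                      # 2.0
mean(truncated(Normal(0.0, 1), lower=0))  # ≈ 0.798, i.e. √(2/π) for a half-normal
```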
In the issue mentioned above, the alternative proposed was to first sample from the priors and then compute the indices of interest. Two approaches were suggested:

```julia
# Existing solution
fit = mod(rand(10), rand(10), rand(10))
priors_samples = hcat([rand(Vector, fit) for _ in 1:100]...)
initial_parameters = mean(priors_samples, dims=2)
```

```julia
# Alternative
fit = mod(rand(10), rand(10), rand(10))
chains = sample(fit, Prior(), 100)
initial_parameters = mean(Array(chains), dims=1)
```

Both of these solutions have some overhead, as one must instantiate the model first. Additionally, using …

My question is: what would be the fastest and most "Turingy" approach? Thanks!
We actually have a method for extracting the priors: `extract_priors`. So I think the following should do the trick:

```julia
mapreduce(DynamicPPL.tovec ∘ mean, vcat, values(extract_priors(model)))
```

This does the following: `extract_priors(model)` returns the prior distribution of each variable, `mean` computes the mean of each prior, `DynamicPPL.tovec` flattens each mean into a vector, and `mapreduce(…, vcat, …)` concatenates the pieces into a single parameter vector.
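For completeness, here is a rough sketch of how that vector could then be used as the starting point for sampling; this assumes a recent Turing version where `sample` accepts the `initial_params` keyword, and the exact call is only illustrative:

```julia
using Turing, DynamicPPL

# Instantiate the model from the question with some dummy data
model = mod(rand(10), rand(10), rand(10))

# Prior means, flattened into a single vector following the order of the model's variables
init = mapreduce(DynamicPPL.tovec ∘ mean, vcat, values(extract_priors(model)))

# Start MCMC from the prior means (keyword name assumed to be `initial_params`)
chain = sample(model, NUTS(), 1000; initial_params=init)
```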
Ah yes, so you have to bump Turing to 0.35 (which should also update DynamicPPL.jl to the most recent version). Then this should work 👍
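In case it helps, one way to do that bump from the REPL (assuming the environment is managed with Pkg; the pinned version is only the one mentioned in this thread):

```julia
using Pkg

# Request Turing 0.35, which should also pull in a DynamicPPL release providing `extract_priors`
Pkg.add(name="Turing", version="0.35")

# Alternatively, just update an existing installation to the latest compatible release
Pkg.update("Turing")
```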