Skip to content

Commit

Permalink
Merge pull request #132 from Julia-Tempering/stan_banana
Browse files Browse the repository at this point in the history
Stan banana example
  • Loading branch information
miguelbiron authored Sep 14, 2023
2 parents 99a1644 + 562060b commit 5f1520d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
23 changes: 23 additions & 0 deletions examples/stan/banana.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// n-dimensional banana, as defined sec 3 of https://doi.org/10.1111/sjos.12532
// uses n1 = 2, n2 = dim => total dims = dim + 1
// easier values from eq 50 of https://arxiv.org/abs/2003.03636
// x ~ N(0, s_a) // s_a = sqrt(inv(2a)), a = 1/20 (2.5 easier)
// y_{1:d}|x ~ N(x^2, s_b) // s_b = sqrt(inv(2b)), b = 5 (50 easier)
data {
int<lower=1> dim;
}
transformed data {
real a, b, s_a, s_b;
a = inv(20);
b = 5.0;
s_a = sqrt(inv(2*a));
s_b = sqrt(inv(2*b));
}
parameters {
real x;
vector[dim] y;
}
model {
x ~ normal(0, s_a);
y ~ normal(square(x), s_b);
}
6 changes: 6 additions & 0 deletions src/targets/toy_stan_target.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ stan_bernoulli(y = [0,1,0,0,0,0,0,0,0,1]) =
json(; y, N = length(y))
)

stan_banana(dim = 9) =
StanLogPotential(
stan_example_path("banana.stan"),
json(; dim)
)

observed_range_squared(x) = (maximum(x) - minimum(x))^2

# the centered one is the "harder" one, see https://mc-stan.org/users/documentation/case-studies/divergences_and_bias.html
Expand Down
1 change: 1 addition & 0 deletions test/test_stan.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testset "Stan examples" begin
pigeons(target = Pigeons.stan_eight_schools(true), n_rounds = 2, n_chains = 2)
pigeons(target = Pigeons.stan_eight_schools(false), n_rounds = 2, n_chains = 2)
pigeons(target = Pigeons.stan_banana(1), record = [online], n_chains = 1, n_rounds = 5, explorer = SliceSampler())

# some examples where an error is interpreted as -Inf:
pigeons(target = Pigeons.stan_funnel(1), record = [online], n_chains = 1, n_rounds = 5, explorer = SliceSampler())
Expand Down

0 comments on commit 5f1520d

Please sign in to comment.