Skip to content

Commit

Permalink
Merge pull request #153 from Julia-Tempering/banana-scale
Browse files Browse the repository at this point in the history
Add scale parameter to banana distribution
  • Loading branch information
nikola-sur authored Oct 8, 2023
2 parents 9d7e6e9 + 3a7a426 commit 02c1c14
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion examples/stan/banana.stan
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// y_{1:d}|x ~ N(x^2, s_b) // s_b = sqrt(inv(2b)), b = 5 (50 easier)
data {
int<lower=1> dim;
real<lower=0> scale;
}
transformed data {
real a, b, s_a, s_b;
Expand All @@ -19,5 +20,5 @@ parameters {
}
model {
x ~ normal(0, s_a);
y ~ normal(square(x), s_b);
y ~ normal(square(x), scale*s_b);
}
4 changes: 2 additions & 2 deletions src/targets/toy_stan_target.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ stan_bernoulli(y = [0,1,0,0,0,0,0,0,0,1]) =
json(; y, N = length(y))
)

stan_banana(dim = 9) =
stan_banana(dim = 9, scale = 1.0) =
StanLogPotential(
stan_example_path("banana.stan"),
json(; dim)
json(; dim, scale)
)

observed_range_squared(x) = (maximum(x) - minimum(x))^2
Expand Down
2 changes: 1 addition & 1 deletion test/test_AAPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using MCMCChains
if !is_windows_in_CI()
@testset "AAPS" begin
pt = pigeons(;
target = Pigeons.stan_banana(1),
target = Pigeons.stan_banana(1, 1.0),
explorer = AAPS(step_size = 2. ^(-4)),
n_chains = 1, n_rounds = 12, record = [traces])
@test abs(23-minimum(ess(Chains(sample_array(pt))).nt.ess)) < 1
Expand Down
2 changes: 1 addition & 1 deletion test/test_stan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ if !is_windows_in_CI()
@testset "Stan examples" begin
pigeons(target = Pigeons.stan_eight_schools(true), n_rounds = 2, n_chains = 2)
pigeons(target = Pigeons.stan_eight_schools(false), n_rounds = 2, n_chains = 2)
pigeons(target = Pigeons.stan_banana(1), record = [online], n_chains = 1, n_rounds = 5, explorer = SliceSampler())
pigeons(target = Pigeons.stan_banana(1, 1.0), record = [online], n_chains = 1, n_rounds = 5, explorer = SliceSampler())

# some examples where an error is interpreted as -Inf:
pigeons(target = Pigeons.stan_funnel(1), record = [online], n_chains = 1, n_rounds = 5, explorer = SliceSampler())
Expand Down

0 comments on commit 02c1c14

Please sign in to comment.