Skip to content

Commit

Permalink
Merge pull request #133 from Julia-Tempering/fix-dimensionality-scaling
Browse files Browse the repository at this point in the history
Fix needed due to change in interface in AdvancedHMC
  • Loading branch information
alexandrebouchard authored Sep 20, 2023
2 parents d2c4ec8 + 6caeb61 commit d97ba35
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 31 deletions.
80 changes: 55 additions & 25 deletions test/supporting/dimensional-analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,41 @@ using Random

Random.seed!(123)

abstract type LogDensity end
(p::LogDensity)(x) = LogDensityProblems.logdensity(p, x)
LogDensityProblems.dimension(p::T) where {T <: LogDensity} = p.dim
LogDensityProblems.capabilities(::Type{T}) where {T <: LogDensity} = LogDensityProblems.LogDensityOrder{1}()
Pigeons.initialization(p::LogDensity, _, _) = zeros(p.dim)


# Define the target distribution using the `LogDensityProblem` interface
struct LogTargetDensity
struct IsoNormal <: LogDensity
dim::Int
end
LogDensityProblems.logdensity(p::IsoNormal, θ) = -sum(abs2, θ) / 2 # standard multivariate normal

struct Funnel <: LogDensity
dim::Int
end
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2 # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity) = p.dim
LogDensityProblems.capabilities(::Type{LogTargetDensity}) = LogDensityProblems.LogDensityOrder{0}()
function LogDensityProblems.logdensity(p::Funnel, z)
# z = (y, x[1], .., x[dim-1])
@assert length(z) == p.dim
sum = 0.0
y = z[1]
sum += logpdf(Normal(0.0, 3.0), y)
sigma_for_others = exp(y/2.0)
for i in 2:p.dim
sum += logpdf(Normal(0.0, sigma_for_others), z[i])
end
return sum
end

# Based off AdvancedHMC README:
function nuts(D)
function nuts(logp)
D = logp.dim
# Choose parameter dimensionality and initial parameter value
initial_θ = randn(D)
logp = LogTargetDensity(D)


# Set the number of samples to draw and warmup iterations
n_samples, n_adapts = 2_000, 1_000

Expand All @@ -46,13 +67,14 @@ function nuts(D)
# - multinomial sampling scheme,
# - generalised No-U-Turn criteria, and
# - windowed adaption for step-size and diagonal mass matrix
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

# Run the sampler to draw samples from the specified Gaussian, where
# - `samples` will store the samples
# - `stats` will store diagnostic statistics for each sample
samples, stats = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=false)
samples, stats = sample(hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=false)


# next: compute logp on each sample
vs = map(s -> LogDensityProblems.logdensity(logp, s), samples)
Expand All @@ -63,9 +85,10 @@ function nuts(D)
return D, n_steps, ess_value
end

function single_chain_pigeons_mvn(D, explorer)
function single_chain_pigeons_mvn(logp, explorer)
pt = pigeons(;
target = toy_mvn_target(D),
target = logp,
reference = logp,
n_chains = 1,
seed = rand(Int),
show_report = false,
Expand All @@ -80,19 +103,21 @@ function single_chain_pigeons_mvn(D, explorer)
end


function auto_mala(D::Int)
function auto_mala(logp)
D = logp.dim
explorer = Pigeons.AutoMALA(exponent_n_refresh = 0.35)
n_steps, ess_value = single_chain_pigeons_mvn(D, explorer)
n_steps, ess_value = single_chain_pigeons_mvn(logp, explorer)
return D, n_steps, ess_value
end


sparse_slicer(D) = slicer(D, true)
dense_slicer(D) = slicer(D, false)
sparse_slicer(logp) = slicer(logp, true)
dense_slicer(logp) = slicer(logp, false)

function slicer(D, sparse::Bool)
function slicer(logp, sparse::Bool)
D = logp.dim
explorer = Pigeons.SliceSampler()
n_steps, ess_value = single_chain_pigeons_mvn(D, explorer)
n_steps, ess_value = single_chain_pigeons_mvn(logp, explorer)
return (sparse ? 1 : D), n_steps, ess_value
end

Expand All @@ -106,13 +131,14 @@ function compute_ess(vs)
end

function scaling_plot(
max,
max;
n_replicates = 1,
sampling_fcts = [
sparse_slicer,
dense_slicer,
nuts,
auto_mala])
auto_mala],
logp_type = IsoNormal)
cost_plot = plot()
ess_plot = plot()
data = Dict()
Expand All @@ -126,7 +152,8 @@ function scaling_plot(
ess = Float64[]
for i in 0:max
@show D = 2^i
@time replicates = [sampling_fct(D) for j in 1:n_replicates]
logp = logp_type(D)
@time replicates = [sampling_fct(logp) for j in 1:n_replicates]
push!(dims, D)

cost_and_ess = mean(
Expand Down Expand Up @@ -156,10 +183,10 @@ function scaling_plot(
data[sampler_symbol] = (; dims, costs)
end

filename_prefix = "benchmarks/scalings_nrep=$(n_replicates)_max=$max"
filename_prefix = "benchmarks/$logp_type/scalings_nrep=$(n_replicates)_max=$max"

slopes = Dict()
mkpath("benchmarks")
mkpath("benchmarks/$logp_type")
open("$filename_prefix.txt", "w") do io
for (k, v) in data
xs = log.(v.dims)
Expand All @@ -170,8 +197,11 @@ function scaling_plot(
end
end

# savefig(cost_plot, "$filename_prefix.pdf")
# savefig(ess_plot, "$(filename_prefix)_ess.pdf")
return (; logp_type, n_replicates, max, data, slopes, cost_plot, ess_plot)
end

return slopes, cost_plot, ess_plot
function save_dim_analysis_plots(tuple)
filename_prefix = "benchmarks/$(tuple.logp_type)/scalings_nrep=$(tuple.n_replicates)_max=$(tuple.max)"
savefig(tuple.cost_plot, "$filename_prefix.pdf")
savefig(tuple.ess_plot, "$(filename_prefix)_ess.pdf")
end
12 changes: 6 additions & 6 deletions test/test_auto_mala.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ include("supporting/dimensional-analysis.jl")

mean_mh_accept(pt) = mean(Pigeons.explorer_mh_prs(pt))

auto_mala(target) =
automala(target) =
pigeons(;
target,
explorer = AutoMALA(),
n_chains = 1, n_rounds = 10, record = record_online())

@testset "Scaling law" begin
scalings, cost_plot, ess_plot = scaling_plot(10, 1, [auto_mala])
@test abs(scalings[:auto_mala] - 1.33) < 0.15
tuple = scaling_plot(7, sampling_fcts = [auto_mala])
@test abs(tuple.slopes[:auto_mala] - 1.33) < 0.15
end

@testset "Step size convergence" begin
Expand All @@ -23,8 +23,8 @@ end
end

@testset "Step size d-scaling" begin
step1d = auto_mala(toy_mvn_target(1)).shared.explorer.step_size
step1000d = auto_mala(toy_mvn_target(1000)).shared.explorer.step_size
step1d = automala(toy_mvn_target(1)).shared.explorer.step_size
step1000d = automala(toy_mvn_target(1000)).shared.explorer.step_size
@test step1000d < step1d # make sure we do shrink eps with d

# should not shrink by more than ~(1000)^(1/3) according to theory
Expand All @@ -42,7 +42,7 @@ end
@testset "AutoMALA dimensional autoscale" begin
for i in 0:3
d = 10^i
@test mean_mh_accept(auto_mala(toy_mvn_target(d))) > 0.4
@test mean_mh_accept(automala(toy_mvn_target(d))) > 0.4
end
end

Expand Down

0 comments on commit d97ba35

Please sign in to comment.