Skip to content

Commit

Permalink
Merge pull request #146 from Julia-Tempering/aaps-fix-autostep
Browse files Browse the repository at this point in the history
aaps: bail on divergences and remove adapt
  • Loading branch information
miguelbiron authored Oct 3, 2023
2 parents 28b4346 + 10f225d commit 220fda0
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 94 deletions.
116 changes: 30 additions & 86 deletions src/explorers/AAPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,9 @@ of -log pi(x). The tuning parameter `K` defines the number of segments to explor
"""
Base.@kwdef struct AAPS{T,TPrec <: Preconditioner}
"""
Reference to the leapfrog step size.
The leapfrog step size.
"""
step_size_ref::Base.RefValue{Float64} = Ref(1.0)

"""
Log of the lower bound on the acceptance probability used for adapting the step size.
"""
adapt_log_lower_bound::Float64 = -0.001 # smallest value with which I can make stan_banana(1) work

"""
Log of the upper bound on the acceptance probability used for adapting the step size.
"""
adapt_log_upper_bound::Float64 = Inf64 # don't increase the step size

"""
Only adapt the step size during the first `adapt_until_round` rounds.
"""
adapt_until_round::Int64 = typemax(Int64) # stan_banana(1) won't work otherwise
step_size::Float64 = 1.0

"""
Maximum number of segments (regions between apogees) to explore.
Expand Down Expand Up @@ -63,55 +48,18 @@ function adapt_explorer(explorer::AAPS, reduced_recorders, current_pt, new_tempe
estimated_target_std_deviations = adapt_preconditioner(explorer.preconditioner, reduced_recorders)
# TODO: adapt K
return AAPS(
explorer.step_size_ref, explorer.adapt_log_lower_bound,
explorer.adapt_log_upper_bound, explorer.adapt_until_round,
explorer.K, explorer.default_autodiff_backend,
explorer.step_size, explorer.K, explorer.default_autodiff_backend,
explorer.preconditioner, estimated_target_std_deviations
)
end

# uses the autoMALA internal `auto_step_size` to find a step size
function find_reasonable_step_size(
explorer::AAPS,
replica,
target_log_potential,
state::AbstractVector
)
recorders = replica.recorders
dim = length(state)
temp_position = get_buffer(recorders.buffers, :aaps_fwd_position_buffer, dim)
temp_velocity = get_buffer(recorders.buffers, :aaps_fwd_velocity_buffer, dim)
temp_precond = get_buffer(recorders.buffers, :aaps_diag_precond, dim)
temp_position .= state
randn!(replica.rng, temp_velocity)
build_preconditioner!(
temp_precond, explorer.preconditioner, replica.rng, explorer.estimated_target_std_deviations
)
old_step_size = explorer.step_size_ref[]
exponent = auto_step_size(
target_log_potential, temp_precond, temp_position, temp_velocity,
recorders, replica.chain, old_step_size,
explorer.adapt_log_lower_bound, explorer.adapt_log_upper_bound)
return old_step_size * (2.0^exponent)
end

#=
Extract info common to all types of target and perform a step!()
=#
function _extract_commons_and_run!(explorer::AAPS, replica, shared, log_potential, state::AbstractVector)
log_potential_autodiff = ADgradient(
explorer.default_autodiff_backend, log_potential, replica.recorders.buffers
)
# TODO: if allowed for more replicas, all of them would write to the shares ref
# need to move this elsewhere where there's only one process acting. but where?
if shared.iterators.scan == 1 && shared.iterators.round <= explorer.adapt_until_round
if n_chains(shared.tempering) == 1
explorer.step_size_ref[] = find_reasonable_step_size(
explorer, replica, log_potential_autodiff, state)
else
@warn "Step-size adaptation for more than 1 chain is unsupported. Skipping." maxlog=1
end
end
aaps!(
replica.rng,
explorer,
Expand Down Expand Up @@ -156,52 +104,54 @@ function aaps!(

# get buffers
dim = length(position)
diag_precond = get_buffer(recorders.buffers, :aaps_diag_precond, dim)
init_position = get_buffer(recorders.buffers, :aaps_init_position, dim)
diag_precond = get_buffer(recorders.buffers, :aaps_diag_precond, dim)
fwd_state, bwd_state = get_fwd_bwd_states(recorders.buffers, dim)

# initialize
init_position .= position # store initial position in case we need to bail due to divergent transition
build_preconditioner!(
diag_precond, explorer.preconditioner, rng, explorer.estimated_target_std_deviations
)
copyto!(fwd_state.position, position)
copyto!(bwd_state.position, position) # start bwd at same position -> requires skipping
randn!(rng, fwd_state.velocity) # sample velocity ~ N(0,I) <=> sample momentum ~ N(0,diag_precond^2)
fwd_state.position .= position
bwd_state.position .= position # start bwd at same position -> requires skipping
randn!(rng, fwd_state.velocity) # sample velocity ~ N(0,I) <=> sample momentum ~ N(0,diag_precond^2)
bwd_state.velocity .= -1 .* fwd_state.velocity

# find the initial segment by moving forward and backward
fwd_wmax = sample_segment!(explorer, fwd_state, target_log_potential, rng, diag_precond)
bwd_wmax = sample_segment!(explorer, bwd_state, target_log_potential, rng, diag_precond, skip_first=true) # avoids double counting initial state
fwd_wmax,valid = sample_segment!(explorer, fwd_state, target_log_potential, rng, diag_precond)
!valid && return # bail
bwd_wmax,valid = sample_segment!(explorer, bwd_state, target_log_potential, rng, diag_precond, skip_first=true) # avoids double counting initial state
!valid && return # bail

# update the Gumbel-max-trick decision
if fwd_wmax > bwd_wmax
wmax = fwd_wmax
copyto!(position, fwd_state.max_position)
position .= fwd_state.max_position
else
wmax = bwd_wmax
copyto!(position, bwd_state.max_position)
position .= bwd_state.max_position
end

# sample segments by continuing from the previous endpoints
# note that K+1 segments are sampled in total, as in the original AAPS implementation
# see https://github.com/ChrisGSherlock/AAPS/blob/c48c59d81031745cf08b6b3d3d9ad53287bf3b34/AAPS.cpp#L311
for _ in 1:explorer.K
if rand(rng, Bool) # extend forward trajectory. avoids specifying in advance how many times we move forward/backward
fwd_wmax = sample_segment!(explorer, fwd_state, target_log_potential, rng, diag_precond)
if fwd_wmax > wmax
wmax = fwd_wmax
copyto!(position, fwd_state.max_position)
end
else
bwd_wmax = sample_segment!(explorer, bwd_state, target_log_potential, rng, diag_precond)
if bwd_wmax > wmax
wmax = bwd_wmax
copyto!(position, bwd_state.max_position)
end
aaps_state = rand(rng, Bool) ? fwd_state : bwd_state # avoids specifying in advance how many times we move forward/backward
new_wmax, valid = sample_segment!(explorer, aaps_state, target_log_potential, rng, diag_precond)
!valid && break
if new_wmax > wmax
wmax = new_wmax
position .= aaps_state.max_position
end
end
# w(z,z') = exp(log_joint) => proposal always accepted
if !valid # bail on the full path
position .= init_position
# else
# note: using w(z,z') = exp(log_joint) => proposal always accepted
# no need to update position, we work in place
# TODO: accept/reject if other proposal is used
end
end

"""
Expand All @@ -215,7 +165,6 @@ function sample_segment!(
diag_precond::Vector;
skip_first::Bool = false # avoid double counting starting state. same as try0 in https://github.com/ChrisGSherlock/AAPS/blob/c48c59d81031745cf08b6b3d3d9ad53287bf3b34/AAPS.cpp#L268
)
step_size = explorer.step_size_ref[]
logp, cgrad = conditioned_target_gradient(target_log_potential, state.position, diag_precond)
(isnan(logp) || isinf(logp)) && error("""
sample_segment!: invalid initial density (logp=$logp).
Expand All @@ -234,18 +183,13 @@ function sample_segment!(
# hence, p^T M^{-1} gradU > 0 ⟺ v^T cgrad < 0, and viceversa
old_sign = sign(dot(state.velocity, cgrad))
while true
leap_frog!(
valid = leap_frog!(
target_log_potential, diag_precond, state.position, state.velocity,
step_size)
explorer.step_size)
!valid && return (wmax,false)
logp, cgrad = conditioned_target_gradient(target_log_potential, state.position, diag_precond)

(isnan(logp) || isinf(logp)) && error("""
sample_segment!: invalid density (logp=$logp).
Try decreasing the step size (got step_size=$step_size)
""")

new_sign = sign(dot(state.velocity, cgrad))
old_sign < 0 && new_sign > 0 && return wmax
old_sign < 0 && new_sign > 0 && return (wmax,true)
old_sign = new_sign
ljoint = log_joint(logp, state.velocity)
w = ljoint + rand(rng, Gumbel())
Expand Down
15 changes: 7 additions & 8 deletions test/test_AAPS.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
using MCMCChains
if !is_windows_in_CI()
@testset "AAPS" begin
rng = SplittableRandom(1)
pt = pigeons(;
target = Pigeons.stan_banana(1),
explorer = AAPS(),
n_chains = 1, n_rounds = 12, record = [traces])
@show min_ess_id = minimum(ess(Chains(sample_array(pt))).nt.ess)
end
@testset "AAPS" begin
pt = pigeons(;
target = Pigeons.stan_banana(1),
explorer = AAPS(step_size = 2. ^(-4)),
n_chains = 1, n_rounds = 12, record = [traces])
@test abs(23-minimum(ess(Chains(sample_array(pt))).nt.ess)) < 1
end
end

0 comments on commit 220fda0

Please sign in to comment.