diff --git a/src/explorers/AAPS.jl b/src/explorers/AAPS.jl index 38e9f307e..b0332ba2c 100644 --- a/src/explorers/AAPS.jl +++ b/src/explorers/AAPS.jl @@ -18,24 +18,9 @@ of -log pi(x). The tuning parameter `K` defines the number of segments to explor """ Base.@kwdef struct AAPS{T,TPrec <: Preconditioner} """ - Reference to the leapfrog step size. + The leapfrog step size. """ - step_size_ref::Base.RefValue{Float64} = Ref(1.0) - - """ - Log of the lower bound on the acceptance probability used for adapting the step size. - """ - adapt_log_lower_bound::Float64 = -0.001 # smallest value with which I can make stan_banana(1) work - - """ - Log of the upper bound on the acceptance probability used for adapting the step size. - """ - adapt_log_upper_bound::Float64 = Inf64 # don't increase the step size - - """ - Only adapt the step size during the first `adapt_until_round` rounds. - """ - adapt_until_round::Int64 = typemax(Int64) # stan_banana(1) won't work otherwise + step_size::Float64 = 1.0 """ Maximum number of segments (regions between apogees) to explore. @@ -63,38 +48,11 @@ function adapt_explorer(explorer::AAPS, reduced_recorders, current_pt, new_tempe estimated_target_std_deviations = adapt_preconditioner(explorer.preconditioner, reduced_recorders) # TODO: adapt K return AAPS( - explorer.step_size_ref, explorer.adapt_log_lower_bound, - explorer.adapt_log_upper_bound, explorer.adapt_until_round, - explorer.K, explorer.default_autodiff_backend, + explorer.step_size, explorer.K, explorer.default_autodiff_backend, explorer.preconditioner, estimated_target_std_deviations ) end -# uses the autoMALA internal `auto_step_size` to find a step size -function find_reasonable_step_size( - explorer::AAPS, - replica, - target_log_potential, - state::AbstractVector - ) - recorders = replica.recorders - dim = length(state) - temp_position = get_buffer(recorders.buffers, :aaps_fwd_position_buffer, dim) - temp_velocity = get_buffer(recorders.buffers, :aaps_fwd_velocity_buffer, dim) - temp_precond = get_buffer(recorders.buffers, :aaps_diag_precond, dim) - temp_position .= state - randn!(replica.rng, temp_velocity) - build_preconditioner!( - temp_precond, explorer.preconditioner, replica.rng, explorer.estimated_target_std_deviations - ) - old_step_size = explorer.step_size_ref[] - exponent = auto_step_size( - target_log_potential, temp_precond, temp_position, temp_velocity, - recorders, replica.chain, old_step_size, - explorer.adapt_log_lower_bound, explorer.adapt_log_upper_bound) - return old_step_size * (2.0^exponent) -end - #= Extract info common to all types of target and perform a step!() =# @@ -102,16 +60,6 @@ function _extract_commons_and_run!(explorer::AAPS, replica, shared, log_potentia log_potential_autodiff = ADgradient( explorer.default_autodiff_backend, log_potential, replica.recorders.buffers ) - # TODO: if allowed for more replicas, all of them would write to the shares ref - # need to move this elsewhere where there's only one process acting. but where? - if shared.iterators.scan == 1 && shared.iterators.round <= explorer.adapt_until_round - if n_chains(shared.tempering) == 1 - explorer.step_size_ref[] = find_reasonable_step_size( - explorer, replica, log_potential_autodiff, state) - else - @warn "Step-size adaptation for more than 1 chain is unsupported. Skipping." maxlog=1 - end - end aaps!( replica.rng, explorer, @@ -156,52 +104,54 @@ function aaps!( # get buffers dim = length(position) - diag_precond = get_buffer(recorders.buffers, :aaps_diag_precond, dim) + init_position = get_buffer(recorders.buffers, :aaps_init_position, dim) + diag_precond = get_buffer(recorders.buffers, :aaps_diag_precond, dim) fwd_state, bwd_state = get_fwd_bwd_states(recorders.buffers, dim) # initialize + init_position .= position # store initial position in case we need to bail due to divergent transition build_preconditioner!( diag_precond, explorer.preconditioner, rng, explorer.estimated_target_std_deviations ) - copyto!(fwd_state.position, position) - copyto!(bwd_state.position, position) # start bwd at same position -> requires skipping - randn!(rng, fwd_state.velocity) # sample velocity ~ N(0,I) <=> sample momentum ~ N(0,diag_precond^2) + fwd_state.position .= position + bwd_state.position .= position # start bwd at same position -> requires skipping + randn!(rng, fwd_state.velocity) # sample velocity ~ N(0,I) <=> sample momentum ~ N(0,diag_precond^2) bwd_state.velocity .= -1 .* fwd_state.velocity # find the initial segment by moving forward and backward - fwd_wmax = sample_segment!(explorer, fwd_state, target_log_potential, rng, diag_precond) - bwd_wmax = sample_segment!(explorer, bwd_state, target_log_potential, rng, diag_precond, skip_first=true) # avoids double counting initial state + fwd_wmax,valid = sample_segment!(explorer, fwd_state, target_log_potential, rng, diag_precond) + !valid && return # bail + bwd_wmax,valid = sample_segment!(explorer, bwd_state, target_log_potential, rng, diag_precond, skip_first=true) # avoids double counting initial state + !valid && return # bail # update the Gumbel-max-trick decision if fwd_wmax > bwd_wmax wmax = fwd_wmax - copyto!(position, fwd_state.max_position) + position .= fwd_state.max_position else wmax = bwd_wmax - copyto!(position, bwd_state.max_position) + position .= bwd_state.max_position end # sample segments by continuing from the previous endpoints # note that K+1 segments are sampled in total, as in the original AAPS implementation # see https://github.com/ChrisGSherlock/AAPS/blob/c48c59d81031745cf08b6b3d3d9ad53287bf3b34/AAPS.cpp#L311 for _ in 1:explorer.K - if rand(rng, Bool) # extend forward trajectory. avoids specifying in advance how many times we move forward/backward - fwd_wmax = sample_segment!(explorer, fwd_state, target_log_potential, rng, diag_precond) - if fwd_wmax > wmax - wmax = fwd_wmax - copyto!(position, fwd_state.max_position) - end - else - bwd_wmax = sample_segment!(explorer, bwd_state, target_log_potential, rng, diag_precond) - if bwd_wmax > wmax - wmax = bwd_wmax - copyto!(position, bwd_state.max_position) - end + aaps_state = rand(rng, Bool) ? fwd_state : bwd_state # avoids specifying in advance how many times we move forward/backward + new_wmax, valid = sample_segment!(explorer, aaps_state, target_log_potential, rng, diag_precond) + !valid && break + if new_wmax > wmax + wmax = new_wmax + position .= aaps_state.max_position end end - # w(z,z') = exp(log_joint) => proposal always accepted + if !valid # bail on the full path + position .= init_position + # else + # note: using w(z,z') = exp(log_joint) => proposal always accepted # no need to update position, we work in place # TODO: accept/reject if other proposal is used + end end """ @@ -215,7 +165,6 @@ function sample_segment!( diag_precond::Vector; skip_first::Bool = false # avoid double counting starting state. same as try0 in https://github.com/ChrisGSherlock/AAPS/blob/c48c59d81031745cf08b6b3d3d9ad53287bf3b34/AAPS.cpp#L268 ) - step_size = explorer.step_size_ref[] logp, cgrad = conditioned_target_gradient(target_log_potential, state.position, diag_precond) (isnan(logp) || isinf(logp)) && error(""" sample_segment!: invalid initial density (logp=$logp). @@ -234,18 +183,13 @@ function sample_segment!( # hence, p^T M^{-1} gradU > 0 ⟺ v^T cgrad < 0, and viceversa old_sign = sign(dot(state.velocity, cgrad)) while true - leap_frog!( + valid = leap_frog!( target_log_potential, diag_precond, state.position, state.velocity, - step_size) + explorer.step_size) + !valid && return (wmax,false) logp, cgrad = conditioned_target_gradient(target_log_potential, state.position, diag_precond) - - (isnan(logp) || isinf(logp)) && error(""" - sample_segment!: invalid density (logp=$logp). - Try decreasing the step size (got step_size=$step_size) - """) - new_sign = sign(dot(state.velocity, cgrad)) - old_sign < 0 && new_sign > 0 && return wmax + old_sign < 0 && new_sign > 0 && return (wmax,true) old_sign = new_sign ljoint = log_joint(logp, state.velocity) w = ljoint + rand(rng, Gumbel()) diff --git a/test/test_AAPS.jl b/test/test_AAPS.jl index 2054b05d9..072080cad 100644 --- a/test/test_AAPS.jl +++ b/test/test_AAPS.jl @@ -1,11 +1,10 @@ using MCMCChains if !is_windows_in_CI() -@testset "AAPS" begin - rng = SplittableRandom(1) - pt = pigeons(; - target = Pigeons.stan_banana(1), - explorer = AAPS(), - n_chains = 1, n_rounds = 12, record = [traces]) - @show min_ess_id = minimum(ess(Chains(sample_array(pt))).nt.ess) -end + @testset "AAPS" begin + pt = pigeons(; + target = Pigeons.stan_banana(1), + explorer = AAPS(step_size = 2. ^(-4)), + n_chains = 1, n_rounds = 12, record = [traces]) + @test abs(23-minimum(ess(Chains(sample_array(pt))).nt.ess)) < 1 + end end