-
Notifications
You must be signed in to change notification settings - Fork 0
/
WarmupHMCDynamicHMCext.jl
108 lines (96 loc) · 5.52 KB
/
WarmupHMCDynamicHMCext.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
module WarmupHMCDynamicHMCext
using WarmupHMC, DynamicHMC, UnPack, Random
import WarmupHMC: reparametrize, find_reparametrization, mcmc_with_reparametrization
import DynamicHMC: default_warmup_stages, default_reporter, NUTS, SamplingLogDensity, _warmup, mcmc, WarmupState, initialize_warmup_state, warmup, InitialStepsizeSearch, TuningNUTS, _empty_posterior_matrix, TreeStatisticsNUTS, Hamiltonian, initial_adaptation_state, make_mcmc_reporter, evaluate_ℓ, current_ϵ, sample_tree, adapt_stepsize, report, REPORT_SIGDIGITS, GaussianKineticEnergy, regularize_M⁻¹, sample_M⁻¹, final_ϵ, mcmc_steps, mcmc_next_step
"""
    mcmc_with_reparametrization(rng, ℓ, N; initialization, warmup_stages, algorithm, reporter)

Run the reparametrization-aware warmup followed by `N` sampling steps, and
return the inference results (`posterior_matrix`, `tree_statistics`) extended
with the final kinetic energy `κ` and stepsize `ϵ`.

Intermediate warmup information is discarded; call
`mcmc_keep_reparametrization` to retain it.
"""
function mcmc_with_reparametrization(rng, ℓ, N;
                                     initialization = (),
                                     warmup_stages = default_warmup_stages(),
                                     algorithm = NUTS(),
                                     reporter = default_reporter())
    full_run = mcmc_keep_reparametrization(rng, ℓ, N;
                                           initialization = initialization,
                                           warmup_stages = warmup_stages,
                                           algorithm = algorithm,
                                           reporter = reporter)
    adapted = full_run.final_reparametrization_state.warmup_state
    return merge(full_run.inference, (κ = adapted.κ, ϵ = adapted.ϵ))
end
"""
    mcmc_keep_reparametrization(rng, ℓ, N; initialization, warmup_stages, algorithm, reporter)

Run every stage in `warmup_stages` on a `ReparametrizationState`, then draw
`N` samples, keeping all intermediate information.

Returns a named tuple with fields `initial_reparametrization_state`, `warmup`
(per-stage results from `_warmup`), `final_reparametrization_state`,
`inference` (the final `mcmc` output), and `sampling_logdensity`.
`initialization` is splatted as keyword arguments into
`initialize_reparametrization_state`.
"""
function mcmc_keep_reparametrization(rng::AbstractRNG, ℓ, N::Integer;
                                     initialization = (),
                                     warmup_stages = default_warmup_stages(),
                                     algorithm = NUTS(),
                                     reporter = default_reporter())
    # One bundle of (rng, ℓ, algorithm, reporter) is shared by all warmup stages and the final run.
    logdensity_bundle = SamplingLogDensity(rng, ℓ, algorithm, reporter)
    start_state = initialize_reparametrization_state(rng, ℓ; initialization...)
    stage_results, end_state = _warmup(logdensity_bundle, warmup_stages, start_state)
    draws = mcmc(logdensity_bundle, N, end_state)
    return (initial_reparametrization_state = start_state,
            warmup = stage_results,
            final_reparametrization_state = end_state,
            inference = draws,
            sampling_logdensity = logdensity_bundle)
end
"""
    ReparametrizationState(reparametrization, warmup_state)

A DynamicHMC `WarmupState` paired with the reparametrization currently used
for sampling.

# Fields
- `reparametrization`: the target that the Hamiltonian is built on during
  sampling. `initialize_reparametrization_state` sets it to the log density `ℓ`
  itself, and metric-adapting `TuningNUTS` stages replace it.
- `warmup_state`: the wrapped `WarmupState` (`Q`, `κ`, `ϵ`). Between stages its
  position `Q` is expressed for the original log density `ℓ`.
"""
struct ReparametrizationState{TReparam,TWarmup<:WarmupState}
    reparametrization::TReparam
    warmup_state::TWarmup
end
"""
    initialize_reparametrization_state(rng, ℓ; kwargs...)

Build the starting `ReparametrizationState`. Its reparametrization is `ℓ`
itself, and its warmup state comes from `initialize_warmup_state(rng, ℓ; kwargs...)`.
"""
initialize_reparametrization_state(rng, ℓ; kwargs...) =
    ReparametrizationState(ℓ, initialize_warmup_state(rng, ℓ; kwargs...))
# A `nothing` warmup stage: forward to DynamicHMC's method on the wrapped
# warmup state and carry the reparametrization through unchanged.
function warmup(sampling_logdensity, stage::Nothing,
                reparametrization_state::ReparametrizationState)
    stage_result, next_warmup_state =
        warmup(sampling_logdensity, stage, reparametrization_state.warmup_state)
    next_state = ReparametrizationState(reparametrization_state.reparametrization,
                                        next_warmup_state)
    return stage_result, next_state
end
# Initial stepsize search: forward to DynamicHMC's method on the wrapped warmup
# state and keep the current reparametrization as-is.
function warmup(sampling_logdensity, stage::InitialStepsizeSearch,
                reparametrization_state::ReparametrizationState)
    inner = reparametrization_state.warmup_state
    stage_output, inner′ = warmup(sampling_logdensity, stage, inner)
    return stage_output,
           ReparametrizationState(reparametrization_state.reparametrization, inner′)
end
# Reparametrization-aware `TuningNUTS` warmup stage.
#
# Runs `tuning.N` transitions (via `sample_tree` with the configured `algorithm`)
# on a Hamiltonian built from the current `reparametrization`, adapting the
# stepsize with `tuning.stepsize_adaptation`. When the stage also adapts a metric
# (`M ≢ Nothing`), it fits a new reparametrization to the collected draws with
# `find_reparametrization`. It then estimates a Gaussian kinetic energy from the
# draws mapped into that new reparametrization, regularized with `tuning.λ`.
#
# Coordinate convention: the position `Q` in both the incoming and outgoing
# `WarmupState` is expressed for the original log density `ℓ`. It is mapped into
# `reparametrization` coordinates for sampling and mapped back at the end.
# `posterior_matrix` holds one draw per column, stored in `ℓ` coordinates.
#
# Returns `((; posterior_matrix, tree_statistics, ϵs), new_reparametrization_state)`.
function warmup(sampling_logdensity, tuning::TuningNUTS{M}, reparametrization_state::ReparametrizationState) where {M}
@unpack rng, ℓ, algorithm, reporter = sampling_logdensity
@unpack reparametrization, warmup_state = reparametrization_state
@unpack Q, κ, ϵ = warmup_state
@unpack N, stepsize_adaptation, λ = tuning
posterior_matrix = _empty_posterior_matrix(Q, N)
tree_statistics = Vector{TreeStatisticsNUTS}(undef, N)
# The Hamiltonian uses the reparametrization (not ℓ) as its log density.
H = Hamiltonian(κ, reparametrization)
ϵ_state = initial_adaptation_state(stepsize_adaptation, ϵ)
ϵs = Vector{Float64}(undef, N)
mcmc_reporter = make_mcmc_reporter(reporter, N;
currently_warmup = true,
tuning = M ≡ Nothing ? "stepsize" : "stepsize and $(M) metric")
# Move the stored position from ℓ coordinates into the sampling coordinates.
Q = evaluate_ℓ(reparametrization, reparametrize(ℓ, reparametrization, Q.q); strict = true)
for i in 1:N
# `ϵ` is reused here as the stepsize of iteration i.
ϵ = current_ϵ(ϵ_state)
ϵs[i] = ϵ
Q, stats = sample_tree(rng, algorithm, H, Q, ϵ)
# Record the draw mapped back to ℓ coordinates.
posterior_matrix[:, i] = reparametrize(reparametrization, ℓ, Q.q)
tree_statistics[i] = stats
ϵ_state = adapt_stepsize(stepsize_adaptation, ϵ_state, stats.acceptance_rate)
report(mcmc_reporter, i; ϵ = round(ϵ; sigdigits = REPORT_SIGDIGITS))
end
# Map the final position back to ℓ coordinates. This must use the
# reparametrization the chain was sampled with, i.e. before the possible
# replacement below.
Q = evaluate_ℓ(ℓ, reparametrize(reparametrization, ℓ, Q.q); strict = true)
if M ≢ Nothing
# Fit a new reparametrization from the ℓ-coordinate draws.
reparametrization = find_reparametrization(ℓ, posterior_matrix)
# Estimate the metric from the draws expressed in the NEW reparametrization's
# coordinates, then regularize it with λ.
κ = GaussianKineticEnergy(regularize_M⁻¹(sample_M⁻¹(M, reparametrize(ℓ, reparametrization, posterior_matrix)), λ))
report(mcmc_reporter, "adaptation finished", adapted_kinetic_energy = κ)
end
# NOTE(review): the stepsize from `final_ϵ(ϵ_state)` was adapted under the
# previous reparametrization/κ. Presumably later stages re-adapt it — confirm.
((; posterior_matrix, tree_statistics, ϵs), ReparametrizationState(reparametrization, WarmupState(Q, κ, final_ϵ(ϵ_state))))
end
# Final sampling run after warmup.
#
# Draws `N` samples on the adapted reparametrization using the adapted kinetic
# energy and stepsize from the wrapped `WarmupState`. Returns
# `(; posterior_matrix, tree_statistics)`, where each column of
# `posterior_matrix` is a draw mapped back to the coordinates of the original
# log density `ℓ`.
function mcmc(sampling_logdensity, N, reparametrization_state::ReparametrizationState)
    rng = sampling_logdensity.rng
    ℓ = sampling_logdensity.ℓ
    algorithm = sampling_logdensity.algorithm
    reporter = sampling_logdensity.reporter
    target = reparametrization_state.reparametrization
    adapted = reparametrization_state.warmup_state
    # Sample from the reparametrized target rather than from ℓ directly.
    reparametrized_logdensity = SamplingLogDensity(rng, target, algorithm, reporter)
    Q₀ = adapted.Q
    draws = _empty_posterior_matrix(Q₀, N)
    tree_statistics = Vector{TreeStatisticsNUTS}(undef, N)
    progress = make_mcmc_reporter(reporter, N; currently_warmup = false)
    steps = mcmc_steps(reparametrized_logdensity, adapted)
    # The stored position is in ℓ coordinates; move it into the sampling coordinates.
    Q = evaluate_ℓ(target, reparametrize(ℓ, target, Q₀.q); strict = true)
    for i in 1:N
        Q, tree_statistics[i] = mcmc_next_step(steps, Q)
        draws[:, i] = reparametrize(target, ℓ, Q.q)
        report(progress, i)
    end
    return (posterior_matrix = draws, tree_statistics = tree_statistics)
end
end