Allow usage of AbstractSampler #2008

Merged: 22 commits, Jul 5, 2023

Commits
5136d05
initial work on allowing AdvancedHMC samplers
torfjelde Jun 13, 2023
6f4f328
simplify the hacky initialize_nuts method
torfjelde Jun 13, 2023
0c711ed
slight generalization
torfjelde Jun 13, 2023
1a13d50
remove unnecessary type constraint
torfjelde Jun 13, 2023
ed2077b
revert changes to sample overloads
torfjelde Jun 18, 2023
8a489f6
use a subtype of InferenceAlgorithm to wrap any sampler
torfjelde Jun 18, 2023
59ac28b
improve usage of SamplerWrapper
torfjelde Jun 18, 2023
008a853
renamed hmc_new.jl to something a bit more indicative
torfjelde Jun 18, 2023
8f698dc
added support for AdvancedMH
torfjelde Jun 18, 2023
817867c
Merge branch 'master' into torfjelde/allow-abstractsampler-draft
torfjelde Jun 18, 2023
2b73181
forgot to change include
torfjelde Jun 18, 2023
5ef1fa8
renamed SamplerWrapper to ExternalSampler and provided a function ext…
torfjelde Jun 20, 2023
82ab311
added tests for Advanced{HMC,MH}
torfjelde Jun 20, 2023
761ff45
Merge branch 'master' into torfjelde/allow-abstractsampler-draft
torfjelde Jun 20, 2023
a1fabca
Merge branch 'master' into torfjelde/allow-abstractsampler-draft
torfjelde Jun 21, 2023
335e868
fixed external tests
torfjelde Jun 21, 2023
d1afddd
change target acceptance rate
torfjelde Jun 26, 2023
6064834
fixed optim tests
torfjelde Jun 27, 2023
63e37f5
remove NelderMead from tests
torfjelde Jun 27, 2023
22cdfeb
allow models with one variance parameter per observation to fail MLE …
torfjelde Jun 27, 2023
b08dd82
no tests (#2028)
JaimeRZP Jul 4, 2023
b0503e5
Merge branch 'master' into torfjelde/allow-abstractsampler-draft
yebai Jul 4, 2023
1 change: 1 addition & 0 deletions src/Turing.jl
@@ -112,6 +112,7 @@ export @model, # modelling
resume,
@logprob_str,
@prob_str,
externalsampler,

setchunksize, # helper
setadbackend,
87 changes: 87 additions & 0 deletions src/contrib/inference/abstractmcmc.jl
@@ -0,0 +1,87 @@
struct TuringState{S,F}
state::S
logdensity::F
end

struct TuringTransition{T,NT<:NamedTuple,F<:AbstractFloat}
θ::T
lp::F
stat::NT
end

function TuringTransition(vi::AbstractVarInfo, t)
theta = tonamedtuple(vi)
lp = getlogp(vi)
return TuringTransition(theta, lp, getstats(t))
end

metadata(t::TuringTransition) = merge((lp = t.lp,), t.stat)
DynamicPPL.getlogp(t::TuringTransition) = t.lp

state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
θ = getparams(transition)
varinfo = DynamicPPL.unflatten(f.varinfo, θ)
# TODO: `deepcopy` is overkill; make more efficient.
varinfo = DynamicPPL.invlink!!(deepcopy(varinfo), f.model)
return TuringTransition(varinfo, transition)
end

# NOTE: Only thing that depends on the underlying sampler.
# Something similar should be part of AbstractMCMC at some point:
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
getparams(transition::AdvancedHMC.Transition) = transition.z.θ
getstats(transition::AdvancedHMC.Transition) = transition.stat

getparams(transition::AdvancedMH.Transition) = transition.params
getstats(transition) = NamedTuple()
Comment on lines +6 to +37

Member: @torfjelde Isn't this obsolete now that #2026 was merged?

Member: I think so. @JaimeRZP, can you do a follow-up PR to unify these Transition types?

Member Author (torfjelde): Yes it is, and I was aware of that; Jaime and I wanted to wait until #2026 had been merged before merging this PR. I was planning to incorporate those changes into this PR before merging. We're also missing a version bump.

I'd appreciate it if we left the merging of a PR to the person who opened it, unless otherwise explicitly stated, particularly now that I'm only days away from being back at full development capacity. This has happened quite a few times now :/

Also, this PR doesn't need to be merged before other functionality can be developed; it's easy enough to depend on the branch directly.

yebai (Jul 5, 2023): I was under the impression that @JaimeRZP needs this to be released. Sorry for the rush; hopefully it didn't break anything! We can always add more changes in a follow-up PR.

Member Author (torfjelde): Ah, gotcha. @JaimeRZP, you can always do ]add Turing#torfjelde/allow-abstractsampler-draft if you want to try out recent developments. And if you want to develop features based on this branch, just create a branch based on this PR and continue from there, as you did with #2028 :)

getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f))

setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Setfield.@set f.varinfo = varinfo
setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo) = setvarinfo(parent(f), varinfo)

# TODO: Do we also support `resume`, etc?
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler_wrapper::Sampler{<:ExternalSampler};
kwargs...
)
sampler = sampler_wrapper.alg.sampler

# Create a log-density function with an implementation of the
# gradient so we ensure that we're using the same AD backend as in Turing.
f = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(model))

# Link the varinfo.
f = setvarinfo(f, DynamicPPL.link!!(getvarinfo(f), model))

# Then just call `AbstractMCMC.step` with the right arguments.
transition_inner, state_inner = AbstractMCMC.step(
rng, AbstractMCMC.LogDensityModel(f), sampler; kwargs...
)

# Convert the transition and state back to Turing-native types.
return transition_to_turing(f, transition_inner), state_to_turing(f, state_inner)
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler_wrapper::Sampler{<:ExternalSampler},
state::TuringState;
kwargs...
)
sampler = sampler_wrapper.alg.sampler
f = state.logdensity

# Then just call `AbstractMCMC.step` with the right arguments.
transition_inner, state_inner = AbstractMCMC.step(
rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
)

# Convert the transition and state back to Turing-native types.
return transition_to_turing(f, transition_inner), state_to_turing(f, state_inner)
end
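For orientation: a minimal sketch of how these two step overloads chain together when driven by hand. The toy model and the random-walk proposal below are hypothetical stand-ins for illustration, not code from this PR.

using Turing, AdvancedMH, AbstractMCMC, DynamicPPL, Distributions, LinearAlgebra, Random

# Toy model (hypothetical); any DynamicPPL.Model works the same way.
@model demo() = x ~ Normal()
model = demo()

rng = Random.default_rng()
spl = DynamicPPL.Sampler(
    externalsampler(AdvancedMH.RWMH(MvNormal(zeros(1), 0.1 * I))), model
)

# The first call dispatches to the initial-step overload above, which
# links the varinfo and builds the AD-backed log-density function...
transition, state = AbstractMCMC.step(rng, model, spl)

# ...subsequent calls dispatch on the `TuringState` and reuse it.
transition, state = AbstractMCMC.step(rng, model, spl, state)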
25 changes: 22 additions & 3 deletions src/inference/Inference.jl
@@ -22,6 +22,7 @@ using DynamicPPL
using AbstractMCMC: AbstractModel, AbstractSampler
using DocStringExtensions: TYPEDEF, TYPEDFIELDS
using DataStructures: OrderedSet
using Setfield: Setfield

import AbstractMCMC
import AdvancedHMC; const AHMC = AdvancedHMC
@@ -66,7 +67,8 @@ export InferenceAlgorithm,
dot_observe,
resume,
predict,
isgibbscomponent
isgibbscomponent,
externalsampler

#######################
# Sampler abstraction #
@@ -77,9 +79,26 @@ abstract type ParticleInference <: InferenceAlgorithm end
abstract type Hamiltonian{AD} <: InferenceAlgorithm end
abstract type StaticHamiltonian{AD} <: Hamiltonian{AD} end
abstract type AdaptiveHamiltonian{AD} <: Hamiltonian{AD} end

getADbackend(::Hamiltonian{AD}) where AD = AD()

"""
ExternalSampler{S<:AbstractSampler}

# Fields
$(TYPEDFIELDS)
"""
struct ExternalSampler{S<:AbstractSampler} <: InferenceAlgorithm
"the sampler to wrap"
sampler::S
end

"""
externalsampler(sampler::AbstractSampler)

Wrap a sampler so it can be used as an inference algorithm.
"""
externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler)

# Algorithm for sampling from the prior
struct Prior <: InferenceAlgorithm end

@@ -246,7 +265,6 @@ function AbstractMCMC.sample(
return AbstractMCMC.sample(rng, model, SampleFromPrior(), ensemble, N, n_chains;
chain_type=chain_type, progress=progress, kwargs...)
end

##########################
# Chain making utilities #
##########################
@@ -442,6 +460,7 @@ include("gibbs_conditional.jl")
include("gibbs.jl")
include("../contrib/inference/sghmc.jl")
include("emcee.jl")
include("../contrib/inference/abstractmcmc.jl")

################
# Typing tools #
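Usage sketch: with the export above in place, any AbstractSampler can be handed to sample like a built-in inference algorithm. The gdemo model below is a hypothetical stand-in, not code from this PR.

using Turing, AdvancedMH, Distributions, LinearAlgebra

@model function gdemo(x)
    σ² ~ InverseGamma(2, 3)
    μ ~ Normal(0, sqrt(σ²))
    x .~ Normal(μ, sqrt(σ²))
end

# Two parameters (σ², μ), hence a 2-dimensional random-walk proposal.
rwmh = AdvancedMH.RWMH(MvNormal(zeros(2), 0.1 * I))
chain = sample(gdemo(randn(100)), externalsampler(rwmh), 2_000)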
75 changes: 75 additions & 0 deletions test/contrib/inference/abstractmcmc.jl
@@ -0,0 +1,75 @@
using Turing.Inference: AdvancedHMC

function initialize_nuts(model::Turing.Model)
# Create a log-density function with an implementation of the
# gradient so we ensure that we're using the same AD backend as in Turing.
f = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(model))

# Link the varinfo.
f = Turing.Inference.setvarinfo(f, DynamicPPL.link!!(Turing.Inference.getvarinfo(f), model))

# Choose parameter dimensionality and initial parameter value
D = LogDensityProblems.dimension(f)
initial_θ = rand(D) .- 0.5

# Define a Hamiltonian system
metric = AdvancedHMC.DiagEuclideanMetric(D)
hamiltonian = AdvancedHMC.Hamiltonian(metric, f)

# Define a leapfrog solver, with initial step size chosen heuristically
initial_ϵ = AdvancedHMC.find_good_stepsize(hamiltonian, initial_θ)
integrator = AdvancedHMC.Leapfrog(initial_ϵ)

# Define an HMC sampler, with the following components
# - multinomial sampling scheme,
# - generalised No-U-Turn criteria, and
# - windowed adaption for step-size and diagonal mass matrix
proposal = AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS,AdvancedHMC.GeneralisedNoUTurn}(integrator)
adaptor = AdvancedHMC.StanHMCAdaptor(
AdvancedHMC.MassMatrixAdaptor(metric),
AdvancedHMC.StepSizeAdaptor(0.65, integrator)
)

return AdvancedHMC.HMCSampler(proposal, metric, adaptor)
end


function initialize_mh(model)
f = DynamicPPL.LogDensityFunction(model)
d = LogDensityProblems.dimension(f)
return AdvancedMH.RWMH(MvNormal(Zeros(d), 0.1 * I))
end

@testset "External samplers" begin
@testset "AdvancedHMC.jl" begin
for model in DynamicPPL.TestUtils.DEMO_MODELS
# Need some functionality to initialize the sampler.
# TODO: Remove this once the constructors in the respective packages become "lazy".
sampler = initialize_nuts(model);
DynamicPPL.TestUtils.test_sampler(
[model],
DynamicPPL.Sampler(externalsampler(sampler), model),
5_000;
nadapts=1_000,
discard_initial=1_000,
rtol=0.2
)
end
end

@testset "AdvancedMH.jl" begin
for model in DynamicPPL.TestUtils.DEMO_MODELS
# Need some functionality to initialize the sampler.
# TODO: Remove this once the constructors in the respective packages become "lazy".
sampler = initialize_mh(model);
DynamicPPL.TestUtils.test_sampler(
[model],
DynamicPPL.Sampler(externalsampler(sampler), model),
10_000;
discard_initial=1_000,
thinning=10,
rtol=0.2
)
end
end
end
31 changes: 26 additions & 5 deletions test/modes/OptimInterface.jl
@@ -159,7 +159,7 @@ end
@testset "MAP for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
result_true = posterior_optima(model)

@testset "$(optimizer)" for optimizer in [LBFGS(), NelderMead()]
@testset "$(nameof(typeof(optimizer)))" for optimizer in [LBFGS(), NelderMead()]
result = optimize(model, MAP(), optimizer)
vals = result.values

@@ -170,21 +170,42 @@ end
end
end
end


# Some of the models have one variance parameter per observation, and so
# the MLE should have the variances set to 0. Since we're working in
# transformed space, this corresponds to `-Inf`, which is of course not achievable.
# In particular, it can result in "early termination" of the optimization process
# because we hit NaNs, etc. To avoid this, we set the `g_tol` and the `f_tol` to
# something larger than the default.
allowed_incorrect_mle = [
DynamicPPL.TestUtils.demo_dot_assume_dot_observe,
DynamicPPL.TestUtils.demo_assume_index_observe,
DynamicPPL.TestUtils.demo_assume_multivariate_observe,
DynamicPPL.TestUtils.demo_assume_observe_literal,
DynamicPPL.TestUtils.demo_dot_assume_observe_submodel,
DynamicPPL.TestUtils.demo_dot_assume_dot_observe_matrix,
DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix,
]
@testset "MLE for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
result_true = likelihood_optima(model)

# `NelderMead` seems to struggle with convergence here, so we exclude it.
@testset "$(optimizer)" for optimizer in [LBFGS(),]
result = optimize(model, MLE(), optimizer)
@testset "$(nameof(typeof(optimizer)))" for optimizer in [LBFGS(),]
result = optimize(model, MLE(), optimizer, Optim.Options(g_tol=1e-3, f_tol=1e-3))
vals = result.values

for vn in DynamicPPL.TestUtils.varnames(model)
for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn))
@test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol=0.05
if model.f in allowed_incorrect_mle
@test isfinite(get(result_true, vn_leaf))
else
@test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol=0.05
end
end
end
end
end
end
end

# Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320
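The failure mode described in the comment above can be seen in isolation: with a per-observation scale parameter, the log-likelihood is unbounded as that scale shrinks to zero. A small standalone illustration (not part of the test suite):

using Distributions

x = 1.5                         # a single observation
# With μ = x, shrinking the scale inflates the log-likelihood without bound:
logpdf(Normal(x, 1e-2), x)      # ≈ 3.7
logpdf(Normal(x, 1e-8), x)      # ≈ 17.5
# The optimum sits at scale 0, i.e. at -Inf in log-transformed space,
# which is why the optimizer can hit NaNs and terminate early.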
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -77,6 +77,7 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($
@timeit_include("inference/Inference.jl")
@timeit_include("contrib/inference/dynamichmc.jl")
@timeit_include("contrib/inference/sghmc.jl")
@timeit_include("contrib/inference/abstractmcmc.jl")
@timeit_include("inference/mh.jl")
end
end