From 6ea2f189cf230f629f6812560292f35fd458ddaa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 08:18:41 +0000 Subject: [PATCH 1/6] added initial impl of ext for MarginalLogDensities.jl --- ext/TuringMarginalLogDensitiesExt.jl | 38 ++++++++++++++++++++++++++++ src/Turing.jl | 2 ++ 2 files changed, 40 insertions(+) create mode 100644 ext/TuringMarginalLogDensitiesExt.jl diff --git a/ext/TuringMarginalLogDensitiesExt.jl b/ext/TuringMarginalLogDensitiesExt.jl new file mode 100644 index 0000000000..f4653da98e --- /dev/null +++ b/ext/TuringMarginalLogDensitiesExt.jl @@ -0,0 +1,38 @@ +module TuringMarginalLogDensitiesExt + +using Turing: Turing, DynamicPPL +using MarginalLogDensities: MarginalLogDensities + + +# Use a struct for this to avoid closure overhead. +struct Drop2ndArgAndFlipSign{F} + f::F +end + +(f::Drop2ndArg)(x, _) = -f.f(x) + + +function Turing.marginalize( + model::DynamicPPL.Model, + varnames::Vector, + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox() +) + # Determine the indices for the variables to marginalise out. + varinfo = DynamicPPL.typed_varinfo(model) + varindices = DynamicPPL.getranges(varinfo, varnames) + # Construct the marginal log-density model. + # Use linked `varinfo` to that we're working in unconstrained space and `OptimizationContext` to ensure + # that the log-abs-det jacobian terms are not included. + context = Turing.Optimisation.OptimizationContext(DynamicPPL.leafcontext(model.context)) + varinfo_linked = DynamicPPL.link(varinfo, model) + f = Base.Fix1( + LogDensityProblems.logdensity, + DynamicPPL.LogDensityFunction(varinfo_linked, model, context) + ) + # HACK: need the sign-flip here because `OptimizationContext` is a hacky impl which + # represent the _negative_ log-density. + mdl = MarginalLogDensity( + Drop2ndArgAndFlipSign(f), varinfo_linked[:], varindices, (), method + ) + return mdl +end diff --git a/src/Turing.jl b/src/Turing.jl index dbfd5c5cf0..1296678e9a 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -58,6 +58,8 @@ using .Optimisation include("experimental/Experimental.jl") include("deprecated.jl") # to be removed in the next minor version release +include("extensions.jl") + ########### # Exports # ########### From 16ed43d6e7d8c8f8f0c1d7f291681b636c7bb0b1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 08:38:19 +0000 Subject: [PATCH 2/6] set up the extension properly + fixed a bug --- Project.toml | 3 +++ ext/TuringMarginalLogDensitiesExt.jl | 13 +++++++------ src/Turing.jl | 1 + 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 2bcc8bee57..8e09ee812b 100644 --- a/Project.toml +++ b/Project.toml @@ -41,10 +41,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Optim = "429524aa-4258-5aef-a3af-852621145aeb" [extensions] TuringDynamicHMCExt = "DynamicHMC" +TuringMarginalLogDensitiesExt = "MarginalLogDensities" TuringOptimExt = "Optim" [compat] @@ -70,6 +72,7 @@ Libtask = "0.8.8" LinearAlgebra = "1" LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" +MarginalLogDensities = "0.3" MCMCChains = "5, 6" NamedArrays = "0.9, 0.10" Optim = "1" diff --git a/ext/TuringMarginalLogDensitiesExt.jl b/ext/TuringMarginalLogDensitiesExt.jl index f4653da98e..886b3a9e43 100644 --- a/ext/TuringMarginalLogDensitiesExt.jl +++ b/ext/TuringMarginalLogDensitiesExt.jl @@ -1,21 +1,20 @@ module TuringMarginalLogDensitiesExt using Turing: Turing, DynamicPPL +using Turing.Inference: LogDensityProblems using MarginalLogDensities: MarginalLogDensities - # Use a struct for this to avoid closure overhead. struct Drop2ndArgAndFlipSign{F} f::F end -(f::Drop2ndArg)(x, _) = -f.f(x) - +(f::Drop2ndArgAndFlipSign)(x, _) = -f.f(x) function Turing.marginalize( model::DynamicPPL.Model, varnames::Vector, - method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox() + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), ) # Determine the indices for the variables to marginalise out. varinfo = DynamicPPL.typed_varinfo(model) @@ -27,12 +26,14 @@ function Turing.marginalize( varinfo_linked = DynamicPPL.link(varinfo, model) f = Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(varinfo_linked, model, context) + DynamicPPL.LogDensityFunction(varinfo_linked, model, context), ) # HACK: need the sign-flip here because `OptimizationContext` is a hacky impl which # represent the _negative_ log-density. - mdl = MarginalLogDensity( + mdl = MarginalLogDensities.MarginalLogDensity( Drop2ndArgAndFlipSign(f), varinfo_linked[:], varindices, (), method ) return mdl end + +end diff --git a/src/Turing.jl b/src/Turing.jl index 1296678e9a..cc741e2890 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -59,6 +59,7 @@ include("experimental/Experimental.jl") include("deprecated.jl") # to be removed in the next minor version release include("extensions.jl") +export marginalize ########### # Exports # From bde7bf5592c533b245957d89d6d8106c770e27d9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 08:42:57 +0000 Subject: [PATCH 3/6] added tests to testsuite --- test/ext/TuringMarginalLogDensitiesExt.jl | 16 ++++++++++++++++ test/runtests.jl | 4 ++++ 2 files changed, 20 insertions(+) create mode 100644 test/ext/TuringMarginalLogDensitiesExt.jl diff --git a/test/ext/TuringMarginalLogDensitiesExt.jl b/test/ext/TuringMarginalLogDensitiesExt.jl new file mode 100644 index 0000000000..2249653954 --- /dev/null +++ b/test/ext/TuringMarginalLogDensitiesExt.jl @@ -0,0 +1,16 @@ +module TuringMarginalLogDensitiesExt + +using Turing, MarginalLogDensities, Test + +@testset "MarginalLogDensities" begin + # Simple test case. + @model function demo() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + model = demo(); + # Marginalize out `x`. + marginalized = marginalize(model, [@varname(x)]); + # Compute the marginal log-density of `y = 0.0`. + @test marginalized([0.0]) ≈ logpdf(Normal(0, √2), 0.0) atol=2e-1 +end diff --git a/test/runtests.jl b/test/runtests.jl index 530219c83b..543e1dc565 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -89,6 +89,10 @@ end @testset "utilities" begin @timeit_include("mcmc/utilities.jl") end + + @testset "extensions" begin + @timeit_include("ext/TuringMarginalLogDensitiesExt.jl") + end end show(TIMEROUTPUT; compact=true, sortby=:firstexec) From 95615689c783d9232b80091ad50da5d1816154b5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 08:53:35 +0000 Subject: [PATCH 4/6] bump compat bounds of MarginalLogDensities.jl to 0.3.6 where ForwardDiff.jl has been fiexd --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8e09ee812b..b21c132dc4 100644 --- a/Project.toml +++ b/Project.toml @@ -72,7 +72,7 @@ Libtask = "0.8.8" LinearAlgebra = "1" LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" -MarginalLogDensities = "0.3" +MarginalLogDensities = "0.3.6" MCMCChains = "5, 6" NamedArrays = "0.9, 0.10" Optim = "1" From ff593cb908179e19fd8e23b749ce5132faa877df Mon Sep 17 00:00:00 2001 From: Tor Fjelde Date: Mon, 9 Dec 2024 13:14:02 +0000 Subject: [PATCH 5/6] bump compat entry so we have the new `vector_getranges` --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b21c132dc4..bad56848a2 100644 --- a/Project.toml +++ b/Project.toml @@ -65,7 +65,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.29, 0.30.4, 0.31" +DynamicPPL = "0.31.4" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" Libtask = "0.8.8" From 30c7c49188dec33b925fef66bbd03e9bc078b522 Mon Sep 17 00:00:00 2001 From: Tor Fjelde Date: Mon, 9 Dec 2024 13:14:26 +0000 Subject: [PATCH 6/6] replaced `getranges` with `vector_getranges` --- ext/TuringMarginalLogDensitiesExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TuringMarginalLogDensitiesExt.jl b/ext/TuringMarginalLogDensitiesExt.jl index 886b3a9e43..d26a6b4a64 100644 --- a/ext/TuringMarginalLogDensitiesExt.jl +++ b/ext/TuringMarginalLogDensitiesExt.jl @@ -18,7 +18,7 @@ function Turing.marginalize( ) # Determine the indices for the variables to marginalise out. varinfo = DynamicPPL.typed_varinfo(model) - varindices = DynamicPPL.getranges(varinfo, varnames) + varindices = DynamicPPL.vector_getranges(varinfo, varnames) # Construct the marginal log-density model. # Use linked `varinfo` to that we're working in unconstrained space and `OptimizationContext` to ensure # that the log-abs-det jacobian terms are not included.