Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Turing v0.33 #189

Merged
merged 20 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ ADTypes = "0.2, 1"
Accessors = "0.1.12"
Distributions = "0.25.87"
DynamicHMC = "3.4.0"
DynamicPPL = "0.24.7, 0.25, 0.27"
Folds = "0.2.2"
DynamicPPL = "0.25.2, 0.27"
Folds = "0.2.9"
ForwardDiff = "0.10.19"
IrrationalConstants = "0.1.1, 0.2"
LinearAlgebra = "1.6"
Expand All @@ -61,8 +61,8 @@ ReverseDiff = "1.4.5"
SciMLBase = "1.95.0, 2"
Statistics = "1.6"
StatsBase = "0.33.7, 0.34"
Transducers = "0.4.66"
Turing = "0.30.5, 0.31, 0.32"
Transducers = "0.4.81"
Turing = "0.31.4, 0.32, 0.33"
UnPack = "1"
julia = "1.6"

Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ StatsFuns = "1"
StatsPlots = "0.14.21, 0.15"
TransformVariables = "0.6.2, 0.7, 0.8"
TransformedLogDensities = "1.0.2"
Turing = "0.30.5, 0.31, 0.32"
Turing = "0.31.4, 0.32, 0.33"
188 changes: 72 additions & 116 deletions ext/PathfinderTuringExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,149 +2,105 @@ module PathfinderTuringExt

if isdefined(Base, :get_extension)
using Accessors: Accessors
using ADTypes: ADTypes
using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
using Pathfinder: Pathfinder
using Random: Random
using Turing: Turing
import Pathfinder: flattened_varnames_list
else # using Requires
using ..Accessors: Accessors
using ..ADTypes: ADTypes
using ..DynamicPPL: DynamicPPL
using ..MCMCChains: MCMCChains
using ..Pathfinder: Pathfinder
using ..Random: Random
using ..Turing: Turing
import ..Pathfinder: flattened_varnames_list
end

# utilities for working with Turing model parameter names using only the DynamicPPL API

function Pathfinder.flattened_varnames_list(model::DynamicPPL.Model)
varnames_ranges = varnames_to_ranges(model)
nsyms = maximum(maximum, values(varnames_ranges))
syms = Vector{Symbol}(undef, nsyms)
for (var_name, range) in varnames_to_ranges(model)
sym = Symbol(var_name)
if length(range) == 1
syms[range[begin]] = sym
continue
end
for i in eachindex(range)
syms[range[i]] = Symbol("$sym[$i]")
end
end
return syms
end

# code snippet shared by @torfjelde
"""
varnames_to_ranges(model::DynamicPPL.Model)
varnames_to_ranges(model::DynamicPPL.VarInfo)
varnames_to_ranges(model::DynamicPPL.Metadata)
create_log_density_problem(model::DynamicPPL.Model)

Get `Dict` mapping variable names in model to their ranges in a corresponding parameter vector.
Create a log density problem from a `model`.

# Examples
The return value is an object implementing the LogDensityProblems API whose log-density is
that of the `model` transformed to unconstrained space with the appropriate log-density
adjustment due to change of variables.
"""
function create_log_density_problem(model::DynamicPPL.Model)
# create an unconstrained VarInfo
varinfo = DynamicPPL.link(DynamicPPL.VarInfo(model), model)
# DefaultContext ensures that the log-density adjustment is computed
prob = DynamicPPL.LogDensityFunction(varinfo, model, DynamicPPL.DefaultContext())
return prob
end

```julia
julia> @model function demo()
s ~ Dirac(1)
x = Matrix{Float64}(undef, 2, 4)
x[1, 1] ~ Dirac(2)
x[2, 1] ~ Dirac(3)
x[3] ~ Dirac(4)
y ~ Dirac(5)
x[4] ~ Dirac(6)
x[:, 3] ~ arraydist([Dirac(7), Dirac(8)])
x[[2, 1], 4] ~ arraydist([Dirac(9), Dirac(10)])
return s, x, y
end
demo (generic function with 2 methods)
"""
draws_to_chains(model::DynamicPPL.Model, draws) -> MCMCChains.Chains

julia> demo()()
(1, Any[2.0 4.0 7 10; 3.0 6.0 8 9], 5)
Convert a `(nparams, ndraws)` matrix of unconstrained `draws` to an `MCMCChains.Chains`
object with corresponding constrained draws and names according to `model`.
"""
function draws_to_chains(model::DynamicPPL.Model, draws::AbstractMatrix)
varinfo = DynamicPPL.link(DynamicPPL.VarInfo(model), model)
draw_con_varinfos = map(eachcol(draws)) do draw
# this re-evaluates the model but allows supporting dynamic bijectors
# https://github.com/TuringLang/Turing.jl/issues/2195
return Turing.Inference.getparams(model, DynamicPPL.unflatten(varinfo, draw))
end
param_con_names = map(first, first(draw_con_varinfos))
draws_con = reduce(
vcat, Iterators.map(transpose ∘ Base.Fix1(map, last), draw_con_varinfos)
)
chns = MCMCChains.Chains(draws_con, param_con_names)
return chns
end

julia> varnames_to_ranges(demo())
Dict{AbstractPPL.VarName, UnitRange{Int64}} with 8 entries:
s => 1:1
x[4] => 5:5
x[:,3] => 6:7
x[1,1] => 2:2
x[2,1] => 3:3
x[[2, 1],4] => 8:9
x[3] => 4:4
y => 10:10
```
"""
function varnames_to_ranges end
transform_to_constrained(
p::AbstractArray, vi::DynamicPPL.VarInfo, model::DynamicPPL.Model
)

varnames_to_ranges(model::DynamicPPL.Model) = varnames_to_ranges(DynamicPPL.VarInfo(model))
function varnames_to_ranges(varinfo::DynamicPPL.UntypedVarInfo)
return varnames_to_ranges(varinfo.metadata)
Transform a vector of parameters `p` from unconstrained to constrained space.
"""
function transform_to_constrained(
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
p::AbstractArray, vi::DynamicPPL.VarInfo, model::DynamicPPL.Model
)
p = copy(p)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
@assert DynamicPPL.istrans(vi)
vi = DynamicPPL.unflatten(vi, p)
p .= DynamicPPL.invlink!!(vi, model)[:]
# Restore the linking, since we mutated vi.
DynamicPPL.link!!(vi, model)
return p
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
end
function varnames_to_ranges(varinfo::DynamicPPL.TypedVarInfo)
offset = 0
dicts = map(varinfo.metadata) do md
vns2ranges = varnames_to_ranges(md)
vals = collect(values(vns2ranges))
vals_offset = map(r -> offset .+ r, vals)
offset += reduce((curr, r) -> max(curr, r[end]), vals; init=0)
Dict(zip(keys(vns2ranges), vals_offset))
end

return reduce(merge, dicts)
end
function varnames_to_ranges(metadata::DynamicPPL.Metadata)
idcs = map(Base.Fix1(getindex, metadata.idcs), metadata.vns)
ranges = metadata.ranges[idcs]
return Dict(zip(metadata.vns, ranges))
end
function Pathfinder.pathfinder(model::DynamicPPL.Model; kwargs...)
log_density_problem = create_log_density_problem(model)
result = Pathfinder.pathfinder(log_density_problem; input=model, kwargs...)

function Pathfinder.pathfinder(
model::DynamicPPL.Model;
rng=Random.GLOBAL_RNG,
init_scale=2,
init_sampler=Pathfinder.UniformSampler(init_scale),
init=nothing,
adtype::ADTypes.AbstractADType=Pathfinder.default_ad(),
kwargs...,
)
var_names = flattened_varnames_list(model)
prob = Turing.optim_problem(
model, Turing.MAP(); constrained=false, init_theta=init, adtype
)
init_sampler(rng, prob.prob.u0)
result = Pathfinder.pathfinder(prob.prob; rng, input=model, kwargs...)
draws = reduce(vcat, transpose.(prob.transform.(eachcol(result.draws))))
chns = MCMCChains.Chains(draws, var_names; info=(; pathfinder_result=result))
result_new = Accessors.@set result.draws_transformed = chns
# add transformed draws as Chains
chains_info = (; pathfinder_result=result)
chains = Accessors.@set draws_to_chains(model, result.draws).info = chains_info
result_new = Accessors.@set result.draws_transformed = chains
return result_new
end

function Pathfinder.multipathfinder(
model::DynamicPPL.Model,
ndraws::Int;
rng=Random.GLOBAL_RNG,
init_scale=2,
init_sampler=Pathfinder.UniformSampler(init_scale),
nruns::Int,
adtype=Pathfinder.default_ad(),
kwargs...,
)
var_names = flattened_varnames_list(model)
fun = Turing.optim_function(model, Turing.MAP(); constrained=false, adtype)
init1 = fun.init()
init = [init_sampler(rng, init1)]
for _ in 2:nruns
push!(init, init_sampler(rng, deepcopy(init1)))
function Pathfinder.multipathfinder(model::DynamicPPL.Model, ndraws::Int; kwargs...)
log_density_problem = create_log_density_problem(model)
result = Pathfinder.multipathfinder(log_density_problem, ndraws; input=model, kwargs...)

# add transformed draws as Chains
chains_info = (; pathfinder_result=result)
chains = Accessors.@set draws_to_chains(model, result.draws).info = chains_info

# add transformed draws as Chains for each individual path
single_path_results_new = map(result.pathfinder_results) do r
single_chains_info = (; pathfinder_result=r)
single_chains = Accessors.@set draws_to_chains(model, r.draws).info =
single_chains_info
r_new = Accessors.@set r.draws_transformed = single_chains
return r_new
end
result = Pathfinder.multipathfinder(fun.func, ndraws; rng, input=model, init, kwargs...)
draws = reduce(vcat, transpose.(fun.transform.(eachcol(result.draws))))
chns = MCMCChains.Chains(draws, var_names; info=(; pathfinder_result=result))
result_new = Accessors.@set result.draws_transformed = chns

result_new = Accessors.@set (Accessors.@set result.draws_transformed =
chains).pathfinder_results = single_path_results_new
return result_new
end

Expand Down
2 changes: 0 additions & 2 deletions src/Pathfinder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ include("resample.jl")
include("singlepath.jl")
include("multipath.jl")

include("integration/turing.jl")

function __init__()
Requires.@require AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" begin
include("integration/advancedhmc.jl")
Expand Down
41 changes: 0 additions & 41 deletions src/integration/turing.jl

This file was deleted.

5 changes: 4 additions & 1 deletion test/integration/Turing/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
LogDensityProblems = "2.1.0"
Pathfinder = "0.9"
Turing = "0.30.5, 0.31, 0.32"
Turing = "0.31.4, 0.32, 0.33"
julia = "1.6"
Loading
Loading