Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for higher-dim Dirichlet, e.g. product_distribution #586

Merged
merged 10 commits into from
Apr 18, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ADTypes = "0.2"
AbstractMCMC = "5"
AbstractPPL = "0.7"
BangBang = "0.3"
Bijectors = "0.13"
Bijectors = "0.13.9"
ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
Expand Down
10 changes: 10 additions & 0 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,16 @@ function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
return with_logabsdet_jacobian(f, x_recon)
end

# NOTE: Necessary to handle product distributions of `Dirichlet` and similar.
function with_logabsdet_jacobian_and_reconstruct(
f::Bijectors.Inverse{<:Bijectors.SimplexBijector}, dist, y
)
(d, ns...) = size(dist)
yreshaped = reshape(y, d - 1, ns...)
x, logjac = with_logabsdet_jacobian(f, yreshaped)
return x, logjac
end

# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
Expand Down
5 changes: 5 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ reconstruct(f, dist, val) = reconstruct(dist, val)
reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
function reconstruct(
::Distribution{ArrayLikeVariate{N}}, val::AbstractArray{<:Real,N}
) where {N}
return copy(val)
end
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)

function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real})
Expand Down
25 changes: 25 additions & 0 deletions test/linking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,29 @@ end
end
end
end

# Related: https://github.com/TuringLang/Turing.jl/issues/2190
@testset "High-dim Dirichlet" begin
@model function demo_highdim_dirichlet(ns...)
return x ~ filldist(Dirichlet(ones(2)), ns...)
end
@testset "ns=$ns" for ns in [
(3,),
(3, 4),
(3, 4, 5)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
]
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
model = demo_highdim_dirichlet(ns...)
example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
# Linked.
vi_linked = if mutable
DynamicPPL.link!!(deepcopy(vi), model)
else
DynamicPPL.link(vi, model)
end
@test length(vi_linked[:]) == prod(ns)
end
end
end
end
Loading