From 00bf10caebfe8300a7752f63d3a929d3698f5d43 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 3 Jun 2023 11:10:21 +0100 Subject: [PATCH] Fix for #255 + some other DistributionsAD-stuff (#259) * fix for #255 and introduction of columnwise * added some tests * version bump * forgot to add the inverse --- Project.toml | 2 +- src/compat/distributionsad.jl | 6 +++--- src/interface.jl | 15 +++++++++++++++ test/interface.jl | 16 ++++++++++++++++ 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index a4ba8114..98f1908f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.12.4" +version = "0.12.5" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/compat/distributionsad.jl b/src/compat/distributionsad.jl index 080097f3..99dc81d5 100644 --- a/src/compat/distributionsad.jl +++ b/src/compat/distributionsad.jl @@ -13,8 +13,8 @@ bijector(::TuringScalMvNormal) = identity bijector(::TuringDiagMvNormal) = identity bijector(::TuringDenseMvNormal) = identity -bijector(d::FillVectorOfUnivariate{Continuous}) = bijector(d.v.value) -bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(bijector(d.dists.value)) +bijector(d::FillVectorOfUnivariate{Continuous}) = elementwise(bijector(d.v.value)) +bijector(d::FillMatrixOfUnivariate{Continuous}) = elementwise(bijector(d.dists.value)) bijector(d::MatrixOfUnivariate{Discrete}) = identity bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists)...) bijector(d::VectorOfMultivariate{Discrete}) = identity @@ -30,7 +30,7 @@ for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector() end end -bijector(d::FillVectorOfMultivariate{Continuous}) = bijector(d.dists.value) +bijector(d::FillVectorOfMultivariate{Continuous}) = columnwise(bijector(d.dists.value)) isdirichlet(::VectorOfMultivariate{Continuous, <:Dirichlet}) = true isdirichlet(::VectorOfMultivariate{Continuous, <:TuringDirichlet}) = true diff --git a/src/interface.jl b/src/interface.jl index a216a921..2efcac98 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -17,6 +17,21 @@ elementwise(f) = Base.Fix1(broadcast, f) # TODO: This is makes dispatching quite a bit easier, but uncertain if this is really # the way to go. elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner)) +const Columnwise{F} = Base.Fix1{typeof(eachcolmaphcat),F} +""" + +Alias for `Base.Fix1(eachcolmaphcat, f)`. + +Represents a function `f` which is applied to each column of an input. +""" +columnwise(f) = Base.Fix1(eachcolmaphcat, f) +inverse(f::Columnwise) = columnwise(inverse(f.x)) + +transform(f::Columnwise, x::AbstractMatrix) = f(x) +function logabsdetjac(f::Columnwise, x::AbstractMatrix) + return sum(Base.Fix1(logabsdetjac, f.x), eachcol(x)) +end +with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x)) ###################### # Bijector interface # diff --git a/test/interface.jl b/test/interface.jl index b24f9da0..e1f8d0e4 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -196,6 +196,22 @@ end end end +@testset "DistributionsAD" begin + @testset "$dist" for dist in [ + filldist(Normal(), 2), + filldist(Normal(), 2, 3), + filldist(Exponential(), 2), + filldist(Exponential(), 2, 3), + filldist(filldist(Exponential(), 2), 3), + ] + x = rand(dist) + b = bijector(dist) + y = b(x) + td = transformed(dist) + @test logpdf(dist, x) - logabsdetjac(b, x) ≈ logpdf(td, y) + end +end + @testset "Stacked <: Bijector" begin # `logabsdetjac` withOUT AD d = Beta()