From c3f26bdb67f3b69273ebc48c21079b96f109e6d3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 3 Jun 2023 10:18:07 +0100 Subject: [PATCH] fix for #255 and introduction of columnwise --- src/compat/distributionsad.jl | 6 +++--- src/interface.jl | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/compat/distributionsad.jl b/src/compat/distributionsad.jl index 080097f3..99dc81d5 100644 --- a/src/compat/distributionsad.jl +++ b/src/compat/distributionsad.jl @@ -13,8 +13,8 @@ bijector(::TuringScalMvNormal) = identity bijector(::TuringDiagMvNormal) = identity bijector(::TuringDenseMvNormal) = identity -bijector(d::FillVectorOfUnivariate{Continuous}) = bijector(d.v.value) -bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(bijector(d.dists.value)) +bijector(d::FillVectorOfUnivariate{Continuous}) = elementwise(bijector(d.v.value)) +bijector(d::FillMatrixOfUnivariate{Continuous}) = elementwise(bijector(d.dists.value)) bijector(d::MatrixOfUnivariate{Discrete}) = identity bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists)...) bijector(d::VectorOfMultivariate{Discrete}) = identity @@ -30,7 +30,7 @@ for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector() end end -bijector(d::FillVectorOfMultivariate{Continuous}) = bijector(d.dists.value) +bijector(d::FillVectorOfMultivariate{Continuous}) = columnwise(bijector(d.dists.value)) isdirichlet(::VectorOfMultivariate{Continuous, <:Dirichlet}) = true isdirichlet(::VectorOfMultivariate{Continuous, <:TuringDirichlet}) = true diff --git a/src/interface.jl b/src/interface.jl index a216a921..14724943 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -17,6 +17,20 @@ elementwise(f) = Base.Fix1(broadcast, f) # TODO: This is makes dispatching quite a bit easier, but uncertain if this is really # the way to go. elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner)) +const Columnwise{F} = Base.Fix1{typeof(eachcolmaphcat),F} +""" + +Alias for `Base.Fix1(eachcolmaphcat, f)`. + +Represents a function `f` which is applied to each column of an input. +""" +columnwise(f) = Base.Fix1(eachcolmaphcat, f) + +transform(f::Columnwise, x::AbstractMatrix) = f(x) +function logabsdetjac(f::Columnwise, x::AbstractMatrix) + return sum(Base.Fix1(logabsdetjac, f.x), eachcol(x)) +end +with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x)) ###################### # Bijector interface #