Skip to content

Commit

Permalink
fix for #255 and introduction of columnwise
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Jun 3, 2023
1 parent 529f6f6 commit c3f26bd
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/compat/distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ bijector(::TuringScalMvNormal) = identity
bijector(::TuringDiagMvNormal) = identity
bijector(::TuringDenseMvNormal) = identity

bijector(d::FillVectorOfUnivariate{Continuous}) = bijector(d.v.value)
bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(bijector(d.dists.value))
bijector(d::FillVectorOfUnivariate{Continuous}) = elementwise(bijector(d.v.value))
bijector(d::FillMatrixOfUnivariate{Continuous}) = elementwise(bijector(d.dists.value))
bijector(d::MatrixOfUnivariate{Discrete}) = identity
bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists)...)
bijector(d::VectorOfMultivariate{Discrete}) = identity
Expand All @@ -30,7 +30,7 @@ for T in (:VectorOfMultivariate, :FillVectorOfMultivariate)
bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector()
end
end
bijector(d::FillVectorOfMultivariate{Continuous}) = bijector(d.dists.value)
bijector(d::FillVectorOfMultivariate{Continuous}) = columnwise(bijector(d.dists.value))

isdirichlet(::VectorOfMultivariate{Continuous, <:Dirichlet}) = true
isdirichlet(::VectorOfMultivariate{Continuous, <:TuringDirichlet}) = true
Expand Down
14 changes: 14 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@ elementwise(f) = Base.Fix1(broadcast, f)
# TODO: This is makes dispatching quite a bit easier, but uncertain if this is really
# the way to go.
elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner))
const Columnwise{F} = Base.Fix1{typeof(eachcolmaphcat),F}
"""
Alias for `Base.Fix1(eachcolmaphcat, f)`.
Represents a function `f` which is applied to each column of an input.
"""
columnwise(f) = Base.Fix1(eachcolmaphcat, f)

transform(f::Columnwise, x::AbstractMatrix) = f(x)
function logabsdetjac(f::Columnwise, x::AbstractMatrix)
return sum(Base.Fix1(logabsdetjac, f.x), eachcol(x))
end
with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x))

######################
# Bijector interface #
Expand Down

0 comments on commit c3f26bd

Please sign in to comment.