Skip to content

Commit

Permalink
Fix for #255 + some other DistributionsAD-stuff (#259)
Browse files Browse the repository at this point in the history
* fix for #255 and introduction of columnwise

* added some tests

* version  bump

* forgot to add the inverse
  • Loading branch information
torfjelde authored Jun 3, 2023
1 parent 529f6f6 commit 00bf10c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.12.4"
version = "0.12.5"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
6 changes: 3 additions & 3 deletions src/compat/distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ bijector(::TuringScalMvNormal) = identity
bijector(::TuringDiagMvNormal) = identity
bijector(::TuringDenseMvNormal) = identity

bijector(d::FillVectorOfUnivariate{Continuous}) = bijector(d.v.value)
bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(bijector(d.dists.value))
bijector(d::FillVectorOfUnivariate{Continuous}) = elementwise(bijector(d.v.value))
bijector(d::FillMatrixOfUnivariate{Continuous}) = elementwise(bijector(d.dists.value))
bijector(d::MatrixOfUnivariate{Discrete}) = identity
bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists)...)
bijector(d::VectorOfMultivariate{Discrete}) = identity
Expand All @@ -30,7 +30,7 @@ for T in (:VectorOfMultivariate, :FillVectorOfMultivariate)
bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector()
end
end
bijector(d::FillVectorOfMultivariate{Continuous}) = bijector(d.dists.value)
bijector(d::FillVectorOfMultivariate{Continuous}) = columnwise(bijector(d.dists.value))

isdirichlet(::VectorOfMultivariate{Continuous, <:Dirichlet}) = true
isdirichlet(::VectorOfMultivariate{Continuous, <:TuringDirichlet}) = true
Expand Down
15 changes: 15 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ elementwise(f) = Base.Fix1(broadcast, f)
# TODO: This is makes dispatching quite a bit easier, but uncertain if this is really
# the way to go.
elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner))
const Columnwise{F} = Base.Fix1{typeof(eachcolmaphcat),F}
"""
Alias for `Base.Fix1(eachcolmaphcat, f)`.
Represents a function `f` which is applied to each column of an input.
"""
columnwise(f) = Base.Fix1(eachcolmaphcat, f)
inverse(f::Columnwise) = columnwise(inverse(f.x))

transform(f::Columnwise, x::AbstractMatrix) = f(x)
function logabsdetjac(f::Columnwise, x::AbstractMatrix)
return sum(Base.Fix1(logabsdetjac, f.x), eachcol(x))
end
with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x))

######################
# Bijector interface #
Expand Down
16 changes: 16 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,22 @@ end
end
end

@testset "DistributionsAD" begin
@testset "$dist" for dist in [
filldist(Normal(), 2),
filldist(Normal(), 2, 3),
filldist(Exponential(), 2),
filldist(Exponential(), 2, 3),
filldist(filldist(Exponential(), 2), 3),
]
x = rand(dist)
b = bijector(dist)
y = b(x)
td = transformed(dist)
@test logpdf(dist, x) - logabsdetjac(b, x) logpdf(td, y)
end
end

@testset "Stacked <: Bijector" begin
# `logabsdetjac` withOUT AD
d = Beta()
Expand Down

2 comments on commit 00bf10c

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/84800

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.5 -m "<description of version>" 00bf10caebfe8300a7752f63d3a929d3698f5d43
git push origin v0.12.5

Please sign in to comment.