Skip to content

Commit

Permalink
Bijector for product distribution (#304)
Browse files Browse the repository at this point in the history
* added has_constant_bijector and made bijector of product distributions
return the identity whenever possible

* no need to limit ourselves to identity for constant bijectors

* no need to limit ourselves to identity for Product

* bump patch version

* Update test/bijectors/ordered.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed tests

* attempt at fix for ordered MvTDist test

* dispatch on GenericMvTDist instead of TDist

* Update test/bijectors/ordered.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added some tests for MvTDist

* make elementwise acting on identity return identity

* fixed bug in error

* added `ProductBijector` which defines bijectors for
`Distributions.ProductDistribution` and similars

* removed redudant comment

* added FillArrays as a test dep

* removed comment

* added tests

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
torfjelde and github-actions[bot] authored May 17, 2024
1 parent d364639 commit 650c548
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/bijectors/product_bijector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
struct ProductBijector{Bs,N} <: Transform
bs::Bs
end

ProductBijector(bs::AbstractArray{<:Any,N}) where {N} = ProductBijector{typeof(bs),N}(bs)

inverse(b::ProductBijector) = ProductBijector(map(inverse, b.bs))

function _product_bijector_check_dim(::Val{N}, ::Val{M}) where {N,M}
if N > M
throw(
DimensionMismatch(
"Number of bijectors needs to be smaller than or equal to the number of dimensions",
),
)
end
end

function _product_bijector_slices(
::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,M}
) where {N,M}
_product_bijector_check_dim(Val(N), Val(M))

# If N < M, then the bijectors expect an input vector of dimension `M - N`.
# To achieve this, we need to slice along the last `N` dimensions.
return eachslice(x; dims=ntuple(i -> i + (M - N), N))
end

# Specialization for case where we're just applying elementwise.
function transform(
b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,N}
) where {N}
return map(transform, b.bs, x)
end
# General case.
function transform(
b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,M}
) where {N,M}
slices = _product_bijector_slices(b, x)
return stack(map(transform, b.bs, slices))
end

function with_logabsdet_jacobian(
b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,N}
) where {N}
results = map(with_logabsdet_jacobian, b.bs, x)
return map(first, results), sum(last, results)
end
function with_logabsdet_jacobian(
b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,M}
) where {N,M}
slices = _product_bijector_slices(b, x)
results = map(with_logabsdet_jacobian, b.bs, slices)
return stack(map(first, results)), sum(last, results)
end

# Other utilities.
function output_size(b::ProductBijector{<:AbstractArray,N}, sz::NTuple{M}) where {N,M}
_product_bijector_check_dim(Val(N), Val(M))

sz_redundant = ntuple(i -> sz[i + (M - N)], N)
sz_example = ntuple(i -> sz[i], M - N)
# NOTE: `Base.stack`, which is used in the transformation, only supports the scenario where
# all `b.bs` have the same output sizes => only need to check the first one.
return (output_size(first(b.bs), sz_example)..., sz_redundant...)
end
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ include("bijectors/corr.jl")
include("bijectors/truncated.jl")
include("bijectors/named_bijector.jl")
include("bijectors/ordered.jl")
include("bijectors/product_bijector.jl")

# Normalizing flow related
include("bijectors/planar_layer.jl")
Expand Down
21 changes: 21 additions & 0 deletions src/transformed_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ has_constant_bijector(d::Type{<:KSOneSided}) = true
function has_constant_bijector(::Type{<:Product{Continuous,D}}) where {D}
return has_constant_bijector(D)
end
function has_constant_bijector(
::Type{<:Distributions.ProductDistribution{<:Any,<:Any,A}}
) where {A}
return has_constant_bijector(eltype(A))
end

# Container distributions.
bijector(d::DiscreteUnivariateDistribution) = identity
Expand Down Expand Up @@ -93,6 +98,22 @@ end
end
end

function bijector(d::Distributions.ProductDistribution{N,0,A}) where {N,A}
# This is the univariate scenario, so if we have a constant bijector
# we can just use the same one for all elements.
return if has_constant_bijector(eltype(A))
elementwise(bijector(d.dists[1]))
else
ProductBijector(map(bijector, d.dists))
end
end

function bijector(d::Distributions.ProductDistribution{N,M,A}) where {N,M,A}
dists = d.dists
bs = bijector.(dists)
return ProductBijector{typeof(bs),N - M}(bs)
end

# Specialized implementations.
bijector(d::Normal) = identity
bijector(d::Distributions.AbstractMvNormal) = identity
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -21,6 +22,7 @@ ChainRulesTestUtils = "0.7, 1"
ChangesOfVariables = "0.1"
Combinatorics = "1.0.2"
DistributionsAD = "0.6.3"
FillArrays = "1"
FiniteDifferences = "0.11, 0.12"
ForwardDiff = "0.10.12"
Functors = "0.1, 0.2, 0.3, 0.4"
Expand Down
69 changes: 69 additions & 0 deletions test/bijectors/product_bijector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using Bijectors: ProductBijector
using FillArrays

has_square_jacobian(b, x) = Bijectors.output_size(b, x) == size(x)

@testset "ProductBijector" begin
# Some distributions.
ds = [
# 1D.
(Normal(), true),
(InverseGamma(), false),
(Beta(), false),
# 2D.
(MvNormal(Zeros(3), I), true),
(Dirichlet(Ones(3)), false),
]

# Stacking a single dimension.
N = 4
@testset "Single-dim stack: $(nameof(typeof(d)))" for (d, isidentity) in ds
b = bijector(d)
xs = [rand(d) for _ in 1:N]
x = stack(xs)

d_prod = product_distribution(Fill(d, N))
b_prod = bijector(d_prod)

sz_true = (Bijectors.output_size(b, size(xs[1]))..., N)
@test Bijectors.output_size(b_prod, size(x)) == sz_true

results = map(xs) do x
with_logabsdet_jacobian(b, x)
end
y, logjac = stack(map(first, results)), sum(last, results)

test_bijector(
b_prod,
x;
y,
logjac,
changes_of_variables_test=has_square_jacobian(b, xs[1]),
test_not_identity=!isidentity,
)
end

@testset "Two-dim stack: $(nameof(typeof(d)))" for (d, isidentity) in ds
b = bijector(d)
xs = [rand(d) for _ in 1:N, _ in 1:(N + 1)]
x = stack(xs)

d_prod = product_distribution(Fill(d, N, N + 1))
b_prod = bijector(d_prod)

sz_true = (Bijectors.output_size(b, size(xs[1]))..., N, N + 1)
@test Bijectors.output_size(b_prod, size(x)) == sz_true

results = map(Base.Fix1(with_logabsdet_jacobian, b), xs)
y, logjac = stack(map(first, results)), sum(last, results)

test_bijector(
b_prod,
x;
y,
logjac,
changes_of_variables_test=has_square_jacobian(b, xs[1]),
test_not_identity=!isidentity,
)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ if GROUP == "All" || GROUP == "Interface"
include("bijectors/pd.jl")
include("bijectors/reshape.jl")
include("bijectors/corr.jl")
include("bijectors/product_bijector.jl")

include("distributionsad.jl")
end
Expand Down

0 comments on commit 650c548

Please sign in to comment.