-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bijector for product distribution (#304)
* added has_constant_bijector and made bijector of product distributions return the identity whenever possible * no need to limit ourselves to identity for constant bijectors * no need to limit ourselves to identity for Product * bump patch version * Update test/bijectors/ordered.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed tests * attempt at fix for ordered MvTDist test * dispatch on GenericMvTDist instead of TDist * Update test/bijectors/ordered.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added some tests for MvTDist * make elementwise acting on identity return identity * fixed bug in error * added `ProductBijector` which defines bijectors for `Distributions.ProductDistribution` and similars * removed redudant comment * added FillArrays as a test dep * removed comment * added tests --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Loading branch information
1 parent
d364639
commit 650c548
Showing
6 changed files
with
160 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
struct ProductBijector{Bs,N} <: Transform | ||
bs::Bs | ||
end | ||
|
||
ProductBijector(bs::AbstractArray{<:Any,N}) where {N} = ProductBijector{typeof(bs),N}(bs) | ||
|
||
inverse(b::ProductBijector) = ProductBijector(map(inverse, b.bs)) | ||
|
||
function _product_bijector_check_dim(::Val{N}, ::Val{M}) where {N,M} | ||
if N > M | ||
throw( | ||
DimensionMismatch( | ||
"Number of bijectors needs to be smaller than or equal to the number of dimensions", | ||
), | ||
) | ||
end | ||
end | ||
|
||
function _product_bijector_slices( | ||
::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,M} | ||
) where {N,M} | ||
_product_bijector_check_dim(Val(N), Val(M)) | ||
|
||
# If N < M, then the bijectors expect an input vector of dimension `M - N`. | ||
# To achieve this, we need to slice along the last `N` dimensions. | ||
return eachslice(x; dims=ntuple(i -> i + (M - N), N)) | ||
end | ||
|
||
# Specialization for case where we're just applying elementwise. | ||
function transform( | ||
b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,N} | ||
) where {N} | ||
return map(transform, b.bs, x) | ||
end | ||
# General case. | ||
function transform( | ||
b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,M} | ||
) where {N,M} | ||
slices = _product_bijector_slices(b, x) | ||
return stack(map(transform, b.bs, slices)) | ||
end | ||
|
||
function with_logabsdet_jacobian( | ||
b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,N} | ||
) where {N} | ||
results = map(with_logabsdet_jacobian, b.bs, x) | ||
return map(first, results), sum(last, results) | ||
end | ||
function with_logabsdet_jacobian( | ||
b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,M} | ||
) where {N,M} | ||
slices = _product_bijector_slices(b, x) | ||
results = map(with_logabsdet_jacobian, b.bs, slices) | ||
return stack(map(first, results)), sum(last, results) | ||
end | ||
|
||
# Other utilities. | ||
function output_size(b::ProductBijector{<:AbstractArray,N}, sz::NTuple{M}) where {N,M} | ||
_product_bijector_check_dim(Val(N), Val(M)) | ||
|
||
sz_redundant = ntuple(i -> sz[i + (M - N)], N) | ||
sz_example = ntuple(i -> sz[i], M - N) | ||
# NOTE: `Base.stack`, which is used in the transformation, only supports the scenario where | ||
# all `b.bs` have the same output sizes => only need to check the first one. | ||
return (output_size(first(b.bs), sz_example)..., sz_redundant...) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
using Bijectors: ProductBijector | ||
using FillArrays | ||
|
||
has_square_jacobian(b, x) = Bijectors.output_size(b, x) == size(x) | ||
|
||
@testset "ProductBijector" begin | ||
# Some distributions. | ||
ds = [ | ||
# 1D. | ||
(Normal(), true), | ||
(InverseGamma(), false), | ||
(Beta(), false), | ||
# 2D. | ||
(MvNormal(Zeros(3), I), true), | ||
(Dirichlet(Ones(3)), false), | ||
] | ||
|
||
# Stacking a single dimension. | ||
N = 4 | ||
@testset "Single-dim stack: $(nameof(typeof(d)))" for (d, isidentity) in ds | ||
b = bijector(d) | ||
xs = [rand(d) for _ in 1:N] | ||
x = stack(xs) | ||
|
||
d_prod = product_distribution(Fill(d, N)) | ||
b_prod = bijector(d_prod) | ||
|
||
sz_true = (Bijectors.output_size(b, size(xs[1]))..., N) | ||
@test Bijectors.output_size(b_prod, size(x)) == sz_true | ||
|
||
results = map(xs) do x | ||
with_logabsdet_jacobian(b, x) | ||
end | ||
y, logjac = stack(map(first, results)), sum(last, results) | ||
|
||
test_bijector( | ||
b_prod, | ||
x; | ||
y, | ||
logjac, | ||
changes_of_variables_test=has_square_jacobian(b, xs[1]), | ||
test_not_identity=!isidentity, | ||
) | ||
end | ||
|
||
@testset "Two-dim stack: $(nameof(typeof(d)))" for (d, isidentity) in ds | ||
b = bijector(d) | ||
xs = [rand(d) for _ in 1:N, _ in 1:(N + 1)] | ||
x = stack(xs) | ||
|
||
d_prod = product_distribution(Fill(d, N, N + 1)) | ||
b_prod = bijector(d_prod) | ||
|
||
sz_true = (Bijectors.output_size(b, size(xs[1]))..., N, N + 1) | ||
@test Bijectors.output_size(b_prod, size(x)) == sz_true | ||
|
||
results = map(Base.Fix1(with_logabsdet_jacobian, b), xs) | ||
y, logjac = stack(map(first, results)), sum(last, results) | ||
|
||
test_bijector( | ||
b_prod, | ||
x; | ||
y, | ||
logjac, | ||
changes_of_variables_test=has_square_jacobian(b, xs[1]), | ||
test_not_identity=!isidentity, | ||
) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters