diff --git a/Project.toml b/Project.toml index 18d60250..80959f00 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.7" +version = "0.13.8" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/interface.jl b/src/interface.jl index 8d401e5a..1ba9aa60 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -42,6 +42,7 @@ with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac( Returns the output size of `f` given the input size `sz`. """ output_size(f, sz) = sz +output_size(f::ComposedFunction, sz) = output_size(f.outer, output_size(f.inner, sz)) """ output_length(f, len::Int) diff --git a/test/bijectors/stacked.jl b/test/bijectors/stacked.jl index 655b63eb..1f221b39 100644 --- a/test/bijectors/stacked.jl +++ b/test/bijectors/stacked.jl @@ -34,4 +34,44 @@ end @test y == [exp(1.0), 2.0] @test binv(y) == [1.0, 2.0, 0.0] end + + @testset "composition" begin + # Composition with one dimension reduction. + b = Stacked((elementwise(exp), ProjectionBijector() ∘ identity), [1:1, 2:3]) + binv = inverse(b) + x = [1.0, 2.0, 3.0] + y = b(x) + x_ = binv(y) + + # Are the values of correct size? + @test size(y) == (2,) + @test size(x_) == (3,) + # Can we determine the sizes correctly? + @test Bijectors.output_size(b, size(x)) == (2,) + @test Bijectors.output_size(binv, size(y)) == (3,) + + # Are values correct? + @test y == [exp(1.0), 2.0] + @test binv(y) == [1.0, 2.0, 0.0] + + # Composition with two dimension reductions. + b = Stacked( + (elementwise(exp), ProjectionBijector() ∘ ProjectionBijector()), [1:1, 2:4] + ) + binv = inverse(b) + x = [1.0, 2.0, 3.0, 4.0] + y = b(x) + x_ = binv(y) + + # Are the values of correct size? + @test size(y) == (2,) + @test size(x_) == (4,) + # Can we determine the sizes correctly? + @test Bijectors.output_size(b, size(x)) == (2,) + @test Bijectors.output_size(binv, size(y)) == (4,) + + # Are values correct? + @test y == [exp(1.0), 2.0] + @test binv(y) == [1.0, 2.0, 0.0, 0.0] + end end