From da9650c0b05e5ae21f728989158fada9377e947c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 29 Aug 2024 13:10:19 +0200 Subject: [PATCH] Fix Dirichlet logpdf_with_trans to work with a Vector{Real} (#326) --- src/Bijectors.jl | 2 +- test/transform.jl | 31 +++++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index df7814fb..55586731 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -164,7 +164,7 @@ function logpdf_with_trans(d::Distribution, x, transform::Bool) if ispd(d) return pd_logpdf_with_trans(d, x, transform) elseif isdirichlet(d) - l = logpdf(d, x .+ eps(eltype(x))) + l = logpdf(d, x .+ _eps(eltype(x))) else l = logpdf(d, x) end diff --git a/test/transform.jl b/test/transform.jl index d08b2dff..85477535 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -42,18 +42,41 @@ function single_sample_tests(dist) # Check that invlink is inverse of link. x = rand(dist) + _single_sample_tests_inner(dist, x, ϵ) + # If the sample is a vector of scalars, check that we can run the tests even if the + # vector has the abstract element type Real. Skip type stability tests though. + if x isa Vector{<:Real} + _single_sample_tests_inner(dist, Vector{Real}(x), ϵ, false) + end +end + +function _single_sample_tests_inner(dist, x, ϵ, test_type_stability=true) if dist isa LKJCholesky x_inv = @inferred Cholesky{Float64,Matrix{Float64}} invlink( dist, link(dist, copy(x)) ) @test x_inv.UL ≈ x.UL atol = 1e-9 else - @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol = 1e-9 + x_reconstructed = if test_type_stability + @inferred invlink(dist, link(dist, copy(x))) + else + invlink(dist, link(dist, copy(x))) + end + @test x_reconstructed ≈ x atol = 1e-9 end # Check that link is inverse of invlink. Hopefully this just holds given the above... - y = @inferred(link(dist, x)) + y = if test_type_stability + @inferred(link(dist, x)) + else + link(dist, x) + end + y_reconstructed = if test_type_stability + @inferred(link(dist, invlink(dist, copy(y)))) + else + link(dist, invlink(dist, copy(y))) + end if dist isa Dirichlet # `logit` and `logistic` are not perfect inverses. This leads to a diversion. # Example: @@ -61,9 +84,9 @@ function single_sample_tests(dist) # 1.0 # julia> logistic(logit(0.9999999999999998)) # 0.9999999999999998 - @test @inferred(link(dist, invlink(dist, copy(y)))) ≈ y atol = 0.5 + @test y_reconstructed ≈ y atol = 0.5 else - @test @inferred(link(dist, invlink(dist, copy(y)))) ≈ y atol = 1e-9 + @test y_reconstructed ≈ y atol = 1e-9 end if dist isa SimplexDistribution # This should probably be exact.