Skip to content

Commit

Permalink
Rename VecCholeskyBijector to VecCorrCholeskyBijector
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Mar 10, 2024
1 parent 2402be2 commit bd6ff3d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
30 changes: 16 additions & 14 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ A bijector to transform a correlation matrix to an unconstrained vector.
# Reference
https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html
See also: [`CorrBijector`](@ref) and ['VecCholeskyBijector'](@ref)
See also: [`CorrBijector`](@ref) and ['VecCorrCholeskyBijector'](@ref)
# Example
Expand Down Expand Up @@ -151,7 +151,7 @@ function output_size(::Inverse{VecCorrBijector}, sz::Tuple{Int})
end

"""
VecCholeskyBijector <: Bijector
VecCorrCholeskyBijector <: Bijector
A bijector to transform a Cholesky factor of a correlation matrix to an unconstrained vector.
Expand All @@ -172,7 +172,7 @@ julia> using LinearAlgebra
julia> using StableRNGs; rng = StableRNG(42);
julia> b = Bijectors.VecCholeskyBijector(:U);
julia> b = Bijectors.VecCorrCholeskyBijector(:U);
julia> X = rand(rng, LKJCholesky(3, 1, :U)) # Sample a correlation matrix.
Cholesky{Float64, Matrix{Float64}}
Expand All @@ -194,9 +194,9 @@ true
julia> X_inv.L ≈ X.L # (✓) Also works for the lower triangular factor.
true
"""
struct VecCholeskyBijector <: Bijector
struct VecCorrCholeskyBijector <: Bijector
mode::Symbol
function VecCholeskyBijector(uplo)
function VecCorrCholeskyBijector(uplo)
s = Symbol(uplo)
if (s === :U) || (s === :L)
new(s)
Expand All @@ -210,39 +210,41 @@ struct VecCholeskyBijector <: Bijector
end
end

Base.@deprecate_binding VecCholeskyBijector VecCorrCholeskyBijector

# TODO: Implement directly to make use of shared computations.
with_logabsdet_jacobian(b::VecCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x)
with_logabsdet_jacobian(b::VecCorrCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x)

function transform(b::VecCholeskyBijector, X)
function transform(b::VecCorrCholeskyBijector, X)
return if b.mode === :U
_link_chol_lkj_from_upper(cholesky_upper(X))
else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor.
else # No need to check for === :L, as it is checked in the VecCorrCholeskyBijector constructor.
_link_chol_lkj_from_lower(cholesky_lower(X))
end
end

function logabsdetjac(b::VecCholeskyBijector, x)
function logabsdetjac(b::VecCorrCholeskyBijector, x)
return -logabsdetjac(inverse(b), b(x))
end

function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real})
function transform(b::Inverse{VecCorrCholeskyBijector}, y::AbstractVector{<:Real})
if b.orig.mode === :U
# This Cholesky constructor is compatible with Julia v1.6
# for later versions Cholesky(::UpperTriangular) works
return Cholesky(_inv_link_chol_lkj(y), 'U', 0)
else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor.
else # No need to check for === :L, as it is checked in the VecCorrCholeskyBijector constructor.
# HACK: Need to make materialize the transposed matrix to avoid numerical instabilities.
# If we don't, the return-type can be both `Matrix` and `Transposed`.
return Cholesky(transpose_eager(_inv_link_chol_lkj(y)), 'L', 0)
end
end

function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real})
function logabsdetjac(::Inverse{VecCorrCholeskyBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_inv_chol(y)
end

output_size(::VecCholeskyBijector, sz::Tuple{Int,Int}) = output_size(VecCorrBijector(), sz)
function output_size(::Inverse{<:VecCholeskyBijector}, sz::Tuple{Int})
output_size(::VecCorrCholeskyBijector, sz::Tuple{Int,Int}) = output_size(VecCorrBijector(), sz)
function output_size(::Inverse{<:VecCorrCholeskyBijector}, sz::Tuple{Int})
return output_size(inverse(VecCorrBijector()), sz)
end

Expand Down
4 changes: 2 additions & 2 deletions test/bijectors/corr.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Bijectors, DistributionsAD, LinearAlgebra, Test
using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
using Bijectors: VecCorrBijector, VecCorrCholeskyBijector, CorrBijector

@testset "CorrBijector & VecCorrBijector" begin
for d in [1, 2, 5]
Expand Down Expand Up @@ -45,7 +45,7 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
end
end

@testset "VecCholeskyBijector" begin
@testset "VecCorrCholeskyBijector" begin
for d in [2, 5]
for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')]
b = bijector(dist)
Expand Down

0 comments on commit bd6ff3d

Please sign in to comment.