Skip to content

Commit

Permalink
AD fix for PDBijector (#280)
Browse files Browse the repository at this point in the history
* added cholesky_lower and cholesky_triangular

* updated PD to use new cholesky_lower and cholesky_upper

* simplified imports in BijectorsReverseDiffExtx

* added ChainRules as a dep since we need the chain rules for cholesky, etc.

* forgot to update Project.toml in previous commit

* added explicit implementation of with_logabsdet_jacobian for PDBijector

* Update src/utils.jl

* added ProjectTo in rrules for cholesky_lower and cholesky_upper to be proper

* added ProjectTo for cholesky_upper too

* added transpose_eager as a alias for permutedims to allow definition
of AD rules without type piracy

* allow usage of ForwardDiff gradient as ground-truth

* added AD tests for PDVecBijector

* added AD tests for PDVecBijector to runtests and commented out all
other tests for the sake of reproducing ReverseDiff bug

* forgot to remove type-piracy def of ReverseDiff rule for permutedims

* use ReverseDiff.@Grad instead of ReverseDiff.@grad_from_chainrules

* only define cholesky_lower and cholesky_upper rules for ReverseDiff, remove rules ChainRules defs

* formatting

* parameterise gradient test for PD bijector properly instead of using
ForwardDiff as per suggestion of @devmotion

* reversed chagne to test_ad

* reactivate tests

* updated doocstrings

* improved PDVecBijector AD tests a bit

* AD fix for CorrBijector (#281)

* removed redundant imports to BijectorsZygoteExt

* use cholesky_upper and cholesky_lower instead of cholesky_factor, etc.

* added tests for CorrVecBijector

* name testset correctly

* use cholesky_lower and cholesky_upper instead of cholesky_factor

* removed now-redundant cholesky_factor

* Fix obsolete function references in tests.  (#282)

* Update chainrules.jl

* Update corr.jl

* Revert changes to transform.

* removed type-piracy that has been addressed upstream and bumped Zygote
version in test

* use :L for Hermitian in `cholesky_lower`

* fixed ForwardDiff tests for LKJCholesky

* fixed tests for matrix dists and added tests for both values of uplo
in LKJCholesky tests

* another attempt at fixing Julia 1.6 tests

---------

Co-authored-by: Hong Ge <[email protected]>

---------

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai authored Aug 12, 2023
1 parent 74d52d4 commit df21aef
Show file tree
Hide file tree
Showing 14 changed files with 195 additions and 75 deletions.
9 changes: 5 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.13.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand All @@ -22,20 +23,20 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"

[extensions]
BijectorsDistributionsADExt = "DistributionsAD"
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsTrackerExt = "Tracker"
BijectorsZygoteExt = "Zygote"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsDistributionsADExt = "DistributionsAD"

[compat]
ArgCheck = "1, 2"
Expand Down
63 changes: 53 additions & 10 deletions ext/BijectorsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ if isdefined(Base, :get_extension)
simplex_logabsdetjac_gradient,
Inverse
import Bijectors:
Bijectors,
_eps,
logabsdetjac,
_logabsdetjac_scale,
Expand All @@ -35,7 +36,8 @@ if isdefined(Base, :get_extension)
find_alpha,
pd_from_lower,
lower_triangular,
upper_triangular
upper_triangular,
transpose_eager

using Bijectors.LinearAlgebra
using Bijectors.Compat: eachcol
Expand All @@ -61,6 +63,7 @@ else
simplex_logabsdetjac_gradient,
Inverse
import ..Bijectors:
Bijectors,
_eps,
logabsdetjac,
_logabsdetjac_scale,
Expand All @@ -75,7 +78,8 @@ else
find_alpha,
pd_from_lower,
lower_triangular,
upper_triangular
upper_triangular,
transpose_eager

using ..Bijectors.LinearAlgebra
using ..Bijectors.Compat: eachcol
Expand Down Expand Up @@ -253,18 +257,57 @@ end
@grad_from_chainrules _transform_ordered(y::Union{TrackedVector,TrackedMatrix})
@grad_from_chainrules _transform_inverse_ordered(x::Union{TrackedVector,TrackedMatrix})

@grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int)
@grad_from_chainrules Bijectors.update_triu_from_vec(
vals::TrackedVector{<:Real}, k::Int, dim::Int
)

@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix)
@grad_from_chainrules _link_chol_lkj_from_upper(x::TrackedMatrix)
@grad_from_chainrules _link_chol_lkj_from_lower(x::TrackedMatrix)
@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector)

if VERSION <= v"1.8.0-DEV.1526"
# HACK: This dispatch does not wrap X in Hermitian before calling cholesky.
# cholesky does not work with AbstractMatrix in julia versions before the compared one,
# and it would error with Hermitian{ReverseDiff.TrackedArray}.
# See commit when the fix was introduced :
# https://github.com/JuliaLang/julia/commit/635449dabee81bba315ab066627a98f856141969
cholesky_factor(X::ReverseDiff.TrackedArray) = cholesky_factor(cholesky(X))
cholesky_lower(X::TrackedMatrix) = track(cholesky_lower, X)
@grad function cholesky_lower(X_tracked::TrackedMatrix)
X = value(X_tracked)
H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :L)
C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false))
function cholesky_lower_pullback(ΔL)
ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :L ? ΔL : ΔL'))
ΔH = cholesky_pullback(ΔC)[2]
Δx = hermitian_pullback(ΔH)[2]
# No need to add pullback for `lower_triangular`, because the pullback
# for `Hermitian` already produces the correct result (i.e. the lower-triangular
# part zeroed out).
return (Δx,)
end

return lower_triangular(parent(C.L)), cholesky_lower_pullback
end

cholesky_upper(X::TrackedMatrix) = track(cholesky_upper, X)
@grad function cholesky_upper(X_tracked::TrackedMatrix)
X = value(X_tracked)
H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :U)
C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false))
function cholesky_upper_pullback(ΔU)
ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :U ? ΔU : ΔU'))
ΔH = cholesky_pullback(ΔC)[2]
Δx = hermitian_pullback(ΔH)[2]
# No need to add pullback for `upper_triangular`, because the pullback
# for `Hermitian` already produces the correct result (i.e. the upper-triangular
# part zeroed out).
return (Δx,)
end

return upper_triangular(parent(C.U)), cholesky_upper_pullback
end

transpose_eager(X::TrackedMatrix) = track(transpose_eager, X)
@grad function transpose_eager(X_tracked::TrackedMatrix)
X = value(X_tracked)
y, y_pullback = ChainRulesCore.rrule(permutedims, X, (2, 1))
transpose_eager_pullback(Δ) = (y_pullback(Δ)[2],)
return y, transpose_eager_pullback
end

end
5 changes: 1 addition & 4 deletions ext/BijectorsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ if isdefined(Base, :get_extension)
_simplex_inv_bijector,
replace_diag,
jacobian,
_inv_link_chol_lkj,
_link_chol_lkj,
_transform_ordered,
_transform_inverse_ordered,
find_alpha,
Expand Down Expand Up @@ -55,8 +53,6 @@ else
_simplex_inv_bijector,
replace_diag,
jacobian,
_inv_link_chol_lkj,
_link_chol_lkj,
_transform_ordered,
_transform_inverse_ordered,
find_alpha,
Expand Down Expand Up @@ -244,4 +240,5 @@ end
return replace_diag(log, Y)
end
end

end
1 change: 1 addition & 0 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import ChangesOfVariables: ChangesOfVariables, with_logabsdet_jacobian
import InverseFunctions: inverse

using ChainRulesCore: ChainRulesCore
using ChainRules: ChainRules
using Functors: Functors
using IrrationalConstants: IrrationalConstants
using LogExpFunctions: LogExpFunctions
Expand Down
22 changes: 13 additions & 9 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@ struct CorrBijector <: Bijector end
with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x)

function transform(b::CorrBijector, X::AbstractMatrix{<:Real})
w = upper_triangular(parent(cholesky(X).U)) # keep LowerTriangular until here can avoid some computation
w = cholesky_upper(X)
r = _link_chol_lkj(w)
return r + zero(X)
# This dense format itself is required by a test, though I can't get the point.
# https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
return r
end

function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
Expand Down Expand Up @@ -127,7 +125,7 @@ struct VecCorrBijector <: Bijector end

with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x)

transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X))
transform(::VecCorrBijector, X) = _link_chol_lkj_from_upper(cholesky_upper(X))

function logabsdetjac(b::VecCorrBijector, x)
return -logabsdetjac(inverse(b), b(x))
Expand Down Expand Up @@ -215,7 +213,13 @@ end
# TODO: Implement directly to make use of shared computations.
with_logabsdet_jacobian(b::VecCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x)

transform(::VecCholeskyBijector, X) = _link_chol_lkj(cholesky_factor(X))
function transform(b::VecCholeskyBijector, X)
return if b.mode === :U
_link_chol_lkj_from_upper(cholesky_upper(X))
else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor.
_link_chol_lkj_from_lower(cholesky_lower(X))
end
end

function logabsdetjac(b::VecCholeskyBijector, x)
return -logabsdetjac(inverse(b), b(x))
Expand All @@ -229,7 +233,7 @@ function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real})
else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor.
# HACK: Need to make materialize the transposed matrix to avoid numerical instabilities.
# If we don't, the return-type can be both `Matrix` and `Transposed`.
return Cholesky(permutedims(_inv_link_chol_lkj(y), (2, 1)), 'L', 0)
return Cholesky(transpose_eager(_inv_link_chol_lkj(y)), 'L', 0)
end
end

Expand Down Expand Up @@ -299,7 +303,7 @@ function _link_chol_lkj(W::AbstractMatrix)
return z
end

function _link_chol_lkj(W::UpperTriangular)
function _link_chol_lkj_from_upper(W::AbstractMatrix)
K = LinearAlgebra.checksquare(W)
N = ((K - 1) * K) ÷ 2 # {K \choose 2} free parameters

Expand All @@ -321,7 +325,7 @@ function _link_chol_lkj(W::UpperTriangular)
return z
end

_link_chol_lkj(W::LowerTriangular) = _link_chol_lkj(transpose(W))
_link_chol_lkj_from_lower(W::AbstractMatrix) = _link_chol_lkj_from_upper(transpose_eager(W))

"""
_inv_link_chol_lkj(y)
Expand Down
31 changes: 11 additions & 20 deletions src/bijectors/pd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,37 @@ function replace_diag(f, X)
return g.(1:size(X, 1), (1:size(X, 2))')
end
transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X)
function pd_link(X)
Y = lower_triangular(parent(cholesky(X; check=true).L))
return replace_diag(log, Y)
end
pd_link(X) = replace_diag(log, cholesky_lower(X))

function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real})
X = replace_diag(exp, Y)
return pd_from_lower(X)
end

function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real})
T = eltype(X)
Xcf = cholesky(X; check=false)
if !issuccess(Xcf)
Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I)
end
return logabsdetjac_pdbijector_chol(Xcf)
L = cholesky_lower(X)
return logabsdetjac_pdbijector_chol(L)
end

function logabsdetjac_pdbijector_chol(Xcf::Cholesky)
# NOTE: Use `UpperTriangular` here because we only need `diag(U)`
# and `UL` is by default already constructed in `Cholesky`.
UL = Xcf.UL
d = size(UL, 1)
z = sum(((d + 1):(-1):2) .* log.(diag(UL)))
function logabsdetjac_pdbijector_chol(X::AbstractMatrix)
d = size(X, 1)
z = sum(((d + 1):(-1):2) .* log.(diag(X)))
return -(z + d * oftype(z, IrrationalConstants.logtwo))
end

# TODO: Implement explicitly.
function with_logabsdet_jacobian(b::PDBijector, X)
return transform(b, X), logabsdetjac(b, X)
L = cholesky_lower(X)
return replace_diag(log, L), logabsdetjac_pdbijector_chol(L)
end

struct PDVecBijector <: Bijector end

transform(::PDVecBijector, X::AbstractMatrix{<:Real}) = pd_vec_link(X)
pd_vec_link(X) = triu_to_vec(transpose(pd_link(X)))
# TODO: Implement `tril_to_vec` and remove `permutedims`.
pd_vec_link(X) = triu_to_vec(transpose_eager(pd_link(X)))

function transform(::Inverse{PDVecBijector}, y::AbstractVector{<:Real})
Y = permutedims(vec_to_triu(y))
Y = transpose_eager(vec_to_triu(y))
return transform(inverse(PDBijector()), Y)
end

Expand Down
12 changes: 6 additions & 6 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM
return y, _transform_inverse_ordered_adjoint
end

function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_upper), W::AbstractMatrix)
K = LinearAlgebra.checksquare(W)
N = ((K - 1) * K) ÷ 2

Expand All @@ -178,7 +178,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
end
end

function pullback_link_chol_lkj(Δz_thunked)
function pullback_link_chol_lkj_from_upper(Δz_thunked)
Δz = ChainRulesCore.unthunk(Δz_thunked)

ΔW = similar(W)
Expand Down Expand Up @@ -208,10 +208,10 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
return ChainRulesCore.NoTangent(), ΔW
end

return z, pullback_link_chol_lkj
return z, pullback_link_chol_lkj_from_upper
end

function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular)
function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_lower), W::AbstractMatrix)
K = LinearAlgebra.checksquare(W)
N = ((K - 1) * K) ÷ 2

Expand All @@ -233,7 +233,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular)
end
end

function pullback_link_chol_lkj(Δz_thunked)
function pullback_link_chol_lkj_from_lower(Δz_thunked)
Δz = ChainRulesCore.unthunk(Δz_thunked)

ΔW = similar(W)
Expand Down Expand Up @@ -263,7 +263,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular)
return ChainRulesCore.NoTangent(), ΔW
end

return z, pullback_link_chol_lkj
return z, pullback_link_chol_lkj_from_lower
end

function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector)
Expand Down
33 changes: 29 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,35 @@ upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A))
pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)'
pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X)

cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(Hermitian(X)))
cholesky_factor(X::Cholesky) = X.U
cholesky_factor(X::UpperTriangular) = X
cholesky_factor(X::LowerTriangular) = X
# HACK: Allows us to define custom chain rules while we wait for upstream fixes.
transpose_eager(X::AbstractMatrix) = permutedims(X)

# TODO: Add `check` as an argument?
"""
cholesky_lower(X)
Return the lower triangular Cholesky factor of `X` as a `Matrix`
rather than `LowerTriangular`.
!!! note
This is a thin wrapper around `cholesky(Hermitian(X)).L`
that returns a `Matrix` rather than `LowerTriangular`.
"""
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X, :L)).L))
cholesky_lower(X::Cholesky) = X.L

"""
cholesky_upper(X)
Return the upper triangular Cholesky factor of `X` as a `Matrix`
rather than `UpperTriangular`.
!!! note
This is a thin wrapper around `cholesky(Hermitian(X)).U`
that returns a `Matrix` rather than `UpperTriangular`.
"""
cholesky_upper(X::AbstractMatrix) = upper_triangular(parent(cholesky(Hermitian(X)).U))
cholesky_upper(X::Cholesky) = X.U

"""
triu_mask(X::AbstractMatrix, k::Int)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ LazyArrays = "1"
LogExpFunctions = "0.3.1"
ReverseDiff = "1.4.2"
Tracker = "0.2.11"
Zygote = "0.5.4, 0.6"
Zygote = "0.6.63"
julia = "1.3"
4 changes: 2 additions & 2 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
# LKJ and LKJCholesky bijector
dist = LKJCholesky(3, 4)
x = rand(dist)
test_rrule(Bijectors._link_chol_lkj, x.U)
test_rrule(Bijectors._link_chol_lkj, x.L)
test_rrule(Bijectors._link_chol_lkj_from_upper, x.U)
test_rrule(Bijectors._link_chol_lkj_from_lower, x.L)

b = bijector(dist)
y = b(x)
Expand Down
Loading

0 comments on commit df21aef

Please sign in to comment.