Cholesky numerical stability: Forward transform #357

Draft · wants to merge 1 commit into base: py/chol-numerical
Conversation

@penelopeysm (Member) commented Dec 1, 2024

This is a companion PR to #356. It attempts to solve the following issue, first reported in #279:

using Bijectors
using Distributions

θ_unconstrained = [
	-1.9887091960524537,
	-13.499454444466279,
	-0.39328331954134665,
	-4.426097270849902,
	13.101175413857023,
	7.66647404712346,
	9.249285786544894,
	4.714877413573335,
	6.233118490809442,
	22.28264809311481
]
n = 5
d = LKJCholesky(n, 10)
b = Bijectors.bijector(d)
b_inv = inverse(b)

θ = b_inv(θ_unconstrained)
Bijectors.logabsdetjac(b, θ)

# ERROR: DomainError with 1.0085229361957693:
# atanh(x) is only defined for |x| ≤ 1.

Introduction

The forward transform acts on an upper triangular matrix, W, whose columns are supposed to be unit vectors, i.e. sum(W[:, j] .^ 2) should be 1 for each j:

julia> s = rand(LKJCholesky(5, 1.0, 'U')).U
5×5 UpperTriangular{Float64, Matrix{Float64}}:
 1.0  0.345448  -0.478      0.455158   0.385151
  ⋅   0.938438  -0.331921  -0.305083  -0.0469749
  ⋅    ⋅         0.813231  -0.397178   0.831726
  ⋅    ⋅          ⋅         0.73621    0.0298828
  ⋅    ⋅          ⋅          ⋅         0.395968

julia> [sum(s[:, i] .^ 2) for i in 1:5]
5-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
 1.0000000000000002

In the forward transform code, remainder_sq is initialised to one and the squares of each element going down column j are successively subtracted from it, so remainder_sq is (in exact arithmetic) the sum of squares of the elements not yet seen.

@inbounds for j in 2:K
    y[idx] = atanh(W[1, j])
    idx += 1
    remainder_sq = 1 - W[1, j]^2
    for i in 2:(j - 1)
        z = W[i, j] / sqrt(remainder_sq)
        y[idx] = atanh(z)
        remainder_sq -= W[i, j]^2
        idx += 1
    end
end

Now, in principle, because z^2 = W[i, j]^2 / sum(W[i:end, j] .^ 2), there is no way that z^2 can be larger than 1.

However, because of floating-point imprecision, this sometimes fails to hold. It is especially likely when the last element W[j-1, j] is very small. This doesn't tend to happen when W is sampled from LKJCholesky, but it can happen when W is obtained through the inverse transformation of a random unconstrained vector, as described in e.g. #279.
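To make the failure mode concrete, here is a toy illustration (the numbers are made up for this comment, not taken from the PR): when the tail of a unit column is tiny, the subtraction 1 - W[1, j]^2 cancels catastrophically, so the computed remainder bears little relation to the true one, whereas the tail-sum formula is bounded by construction.

# A unit-norm column whose tail elements are tiny: the true remainder
# after the first element is ~2e-18.
w = [sqrt(1 - 2e-18), 1e-9, 1e-9]

# Subtractive update, as in the existing code: 1 - w[1]^2 cancels
# catastrophically (here it rounds all the way to 0.0), so z blows up;
# in milder cases it merely lands slightly above 1.
remainder_sq = 1 - w[1]^2
z_old = w[2] / sqrt(remainder_sq)         # Inf here, so atanh would throw

# Tail sum, as in the proposed fix: the numerator is one term of the
# sum under the square root, so |z| cannot exceed 1.
z_new = w[2] / sqrt(sum(w[2:end] .^ 2))   # ≈ 0.707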

A proposed fix: instead of subtracting successive squares from 1, simply compute remainder_sq as that sum directly:

    @inbounds for j in 2:K
-       remainder_sq = 1 - W[1, j]^2
        for i in 2:(j - 1)
+           remainder_sq = sum(W[i:end, j] .^ 2)
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = atanh(z)
-           remainder_sq -= W[i, j]^2
            idx += 1
        end
    end

(In practice, I short-circuited the sqrt by using norm(v) instead of sqrt(sum(v .^ 2)).)

Now, because W[i, j]^2 is one term of that sum, |z| can no longer be larger than 1, and atanh doesn't throw a DomainError.

Setup code for this comment

using Bijectors
using LinearAlgebra
using Distributions
using Random
using Plots
using LogExpFunctions

# Using the invlink definition from this PR
# Recover K from d = K(K-1)/2, the number of strictly upper triangular entries
_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2
function _inv_link_chol_lkj_new(y::AbstractVector)
    LinearAlgebra.require_one_based_indexing(y)
    K = _triu1_dim_from_length(length(y))
    W = similar(y, K, K)
    T = float(eltype(W))
    logJ = zero(T)
    idx = 1
    @inbounds for j in 1:K
        log_remainder = zero(T)  # log of proportion of unit vector remaining
        for i in 1:(j - 1)
            z = tanh(y[idx])
            W[i, j] = z * exp(log_remainder)
            log_remainder -= LogExpFunctions.logcosh(y[idx])
            logJ += log_remainder
            idx += 1
        end
        logJ += log_remainder
        W[j, j] = exp(log_remainder)
        for i in (j + 1):K
            W[i, j] = 0
        end
    end
    return W, logJ
end

# Existing link implementation
function _link_chol_lkj_from_upper_old(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2   # {K \choose 2} free parameters
    y = similar(W, N)
    idx = 1
    @inbounds for j in 2:K
        y[idx] = atanh(W[1, j])
        idx += 1
        remainder_sq = 1 - W[1, j]^2
        for i in 2:(j - 1)
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = atanh(z)
            remainder_sq -= W[i, j]^2
            idx += 1
        end
    end
    return y
end

# New proposal
function _link_chol_lkj_from_upper_new(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2   # {K \choose 2} free parameters
    y = similar(W, N)
    idx = 1
    @inbounds for j in 2:K
        y[idx] = atanh(W[1, j])
        idx += 1
        for i in 2:(j - 1)
            remainder = norm(W[i:end, j])
            z = W[i, j] / remainder
            y[idx] = atanh(z)
            idx += 1
        end
    end
    return y
end

function plot_maes(samples)
    log_mae_old = log10.([sample[1] for sample in samples])
    log_mae_new = log10.([sample[2] for sample in samples])
    scatter(log_mae_old, log_mae_new, label="")
    lim_min = floor(min(minimum(log_mae_old), minimum(log_mae_new)))
    lim_max = ceil(max(maximum(log_mae_old), maximum(log_mae_new)))
    plot!(lim_min:lim_max, lim_min:lim_max, label="y=x", color=:black)
    xlabel!("log10(maximum abs error old)")
    ylabel!("log10(maximum abs error new)")
end

function test_forward_bijector(f_old, f_new)
    dist = LKJCholesky(5, 1.0, 'U')
    Random.seed!(468)
    samples = map(1:500) do _
        x = rand(dist)
        x_again_old = _inv_link_chol_lkj_new(f_old(x.U))[1]
        x_again_new = _inv_link_chol_lkj_new(f_new(x.U))[1]
        # Return the maximum absolute error between the original sample
        # and sample after roundtrip transformation
        (maximum(abs.(x.U - x_again_old)), maximum(abs.(x.U - x_again_new)))
    end
    return samples
end

Impacts of this change

First, let's check the roundtrip transformation on typical samples from LKJCholesky. The numerical accuracy of the new implementation is actually marginally better than that of the existing one:

julia> plot_maes(test_forward_bijector(_link_chol_lkj_from_upper_old, _link_chol_lkj_from_upper_new))

[Figure: bijector_forward_typical — scatter of log10(maximum abs error old) vs log10(maximum abs error new), with the line y = x for reference]

On top of that, it fixes the DomainErrors which occur with random unconstrained inputs:

julia> y = rand(Random.Xoshiro(468), 10) * 16;

julia> x = _inv_link_chol_lkj_new(y)[1];

julia> y_old = _link_chol_lkj_from_upper_old(x)
ERROR: DomainError with 1.000207932997037:
atanh(x) is only defined for |x| ≤ 1.
Stacktrace:
 [1] atanh_domain_error(x::Float64)
   @ Base.Math ./special/hyperbolic.jl:240
 [2] atanh
   @ ./special/hyperbolic.jl:256 [inlined]
 [3] _link_chol_lkj_from_upper_old(W::Matrix{Float64})
   @ Main ./REPL[29]:12
 [4] top-level scope
   @ REPL[78]:1

julia> y_new = _link_chol_lkj_from_upper_new(x)
10-element Vector{Float64}:
  1.7139942709891685
  4.050190371709019
 12.606351374271206
  8.239542965781226
  7.897855159032619
  6.885928358454504
  7.201266901997009
  4.588778566499247
  5.507106236959028
 11.582258189742753

Remaining concerns 1: performance

It's bad.

julia> using Chairmarks

julia> @be (rand(LKJCholesky(5, 1.0, 'U'))) _link_chol_lkj_from_upper_old(_.U)
Benchmark: 2915 samples with 285 evaluations
 min    94.007 ns (2 allocs: 144 bytes)
 median 102.193 ns (2 allocs: 144 bytes)
 mean   111.010 ns (2 allocs: 144 bytes, 0.25% gc time)
 max    8.571 μs (2 allocs: 144 bytes, 97.63% gc time)

julia> @be (rand(LKJCholesky(5, 1.0, 'U'))) _link_chol_lkj_from_upper_new(_.U)
Benchmark: 2632 samples with 90 evaluations
 min    319.444 ns (14 allocs: 672 bytes)
 median 335.189 ns (14 allocs: 672 bytes)
 mean   387.088 ns (14 allocs: 672 bytes, 0.15% gc time)
 max    40.176 μs (14 allocs: 672 bytes, 98.47% gc time)

Remaining concerns 2: accuracy on pathological samples

It's not great, but considering that the existing implementation throws a DomainError, this is still a net win.

julia> y = rand(Random.Xoshiro(468), 10) * 16
10-element Vector{Float64}:
  1.7139942709891685
  4.050190371708977
 12.606352576618578
  8.239542965660522
  7.897855158738416
  6.885928358486035
  7.201266902006305
  4.588778566499414
  5.507106236959235
 11.582258611360368

julia> x = _inv_link_chol_lkj_new(y)[1];

julia> y_new = _link_chol_lkj_from_upper_new(x)
10-element Vector{Float64}:
  1.7139942709891685
  4.050190371709019
 12.606351374271206
  8.239542965781226
  7.897855159032619
  6.885928358454504
  7.201266901997009
  4.588778566499247
  5.507106236959028
 11.582258189742753

julia> maximum(abs.(y - y_new))
1.2023473718869582e-6

Hybrid implementation?

One option to improve performance is to keep the default subtractive implementation, except when W[i, j] > sqrt(remainder_sq) (i.e. when the computed z would exceed 1); in that case, remainder_sq is recalculated by summation rather than subtraction. This introduces a much smaller overhead:

function _link_chol_lkj_from_upper_hybrid(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2   # {K \choose 2} free parameters
    y = similar(W, N)
    idx = 1
    @inbounds for j in 2:K
        y[idx] = atanh(W[1, j])
        idx += 1
        remainder_sq = 1 - W[1, j]^2
        for i in 2:(j - 1)
            remainder = sqrt(remainder_sq)
            if W[i, j] > remainder
                # Recalculate remainder
                z = W[i, j] / norm(W[i:end, j])
            else
                z = W[i, j] / remainder
            end
            y[idx] = atanh(z)
            remainder_sq -= W[i, j]^2
            idx += 1
        end
    end
    return y
end

julia> @be (rand(LKJCholesky(5, 1.0, 'U'))) _link_chol_lkj_from_upper_hybrid(_.U)
Benchmark: 2816 samples with 236 evaluations
 min    117.936 ns (2 allocs: 144 bytes)
 median 126.059 ns (2 allocs: 144 bytes)
 mean   137.724 ns (2 allocs: 144 bytes, 0.34% gc time)
 max    11.727 μs (2 allocs: 144 bytes, 97.91% gc time)

(From above: the original implementation averaged 111 ns, and the pure new implementation, which recalculates on every step, averaged 387 ns.)

Unfortunately, this hybrid implementation is numerically rather unstable: the subtractive remainder_sq can be badly inaccurate without ever triggering the recalculation branch, so later entries are computed against a wrong denominator. Using it could therefore introduce silent errors:

julia> y = rand(Random.Xoshiro(468), 10) * 16
10-element Vector{Float64}:
  1.7139942709891685
  4.050190371708977
 12.606352576618578
  8.239542965660522
  7.897855158738416
  6.885928358486035
  7.201266902006305
  4.588778566499414
  5.507106236959235
 11.582258611360368

julia> x = _inv_link_chol_lkj_new(y)[1];

julia> y_new = _link_chol_lkj_from_upper_hybrid(x)
10-element Vector{Float64}:
  1.7139942709891685
  4.050190371709019
 12.60765841570706
  8.239542965781226
  7.898065159405475
  6.885928358454504
  7.201266901997009
  4.588778545333437
  5.5067849173327055
  4.368093957587587

julia> maximum(abs.(y - y_new))
7.214164653772781

@penelopeysm (Member, Author) commented Dec 1, 2024

CI is failing because ForwardDiff calculates quite a different Jacobian in this test.

Since the new implementation still works correctly on roundtrip transformation, and this PR doesn't touch anything to do with Jacobian calculations, I wonder if this is a ForwardDiff bug (?)

@testset "LKJCholesky" begin
@testset "uplo: $uplo" for uplo in [:L, :U]
dist = LKJCholesky(3, 1, uplo)
single_sample_tests(dist)
x = rand(dist)
inds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
(uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
]
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
J = J[:, inds]
logpdf_turing = logpdf_with_trans(dist, x, true)
@test logpdf(dist, x) - _logabsdet(J) logpdf_turing
end
end

Repro:

using Bijectors
using ForwardDiff: ForwardDiff
using LinearAlgebra: logabsdet, I, Cholesky
using Random

uplo = :L
dist = LKJCholesky(3, 1, uplo)
x = rand(Xoshiro(468), dist)
inds = [
    LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
    (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
]
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
J = J[:, inds]
logabsdet(J)[1]

Before this PR:

julia> x = rand(Xoshiro(468), dist)
Cholesky{Float64, Matrix{Float64}}
L factor:
3×3 LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}:
  1.0         ⋅          ⋅
 -0.23039    0.973098    ⋅
  0.288231  -0.899424   0.328572

julia> logabsdet(J)[1]
2.3239053137427703

With this PR:

julia> logabsdet(J)[1]
0.184638090189601

@@ -297,11 +297,10 @@ function _link_chol_lkj(W::AbstractMatrix)
    # Some zero filling can be avoided, though the diagonal still needs to be filled with zeros.

    @inbounds for j in 1:K
        remainder_sq = one(eltype(W))
        for i in 1:(j - 1)
A reviewer (Member) commented on this diff:

You could save a few operations by reversing the loop and summing remainder_sq incrementally.
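For reference, here is a sketch of what that might look like (my reading of the suggestion, not code from the PR; the function name is made up, and it assumes the setup code above so LinearAlgebra is loaded). Walking each column bottom-up lets remainder_sq be built with a single addition per step, while each z still divides by the exact tail sum, avoiding the repeated norm(W[i:end, j]) re-summation:

# Hypothetical sketch of the reviewer's suggestion: accumulate the tail
# sums bottom-up by addition instead of re-summing the slice every step.
function _link_chol_lkj_from_upper_reversed(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2   # {K \choose 2} free parameters
    y = similar(W, N)
    col_offset = 0          # y entries already filled by previous columns
    @inbounds for j in 2:K
        # At row i we need sum(W[i:j, j] .^ 2); build it incrementally
        # from the bottom of the column upwards.
        remainder_sq = W[j, j]^2
        for i in (j - 1):-1:2
            remainder_sq += W[i, j]^2
            y[col_offset + i] = atanh(W[i, j] / sqrt(remainder_sq))
        end
        y[col_offset + 1] = atanh(W[1, j])
        col_offset += j - 1
    end
    return y
end

This fills the same y entries as _link_chol_lkj_from_upper_new (row i of column j lands at position col_offset + i), just without allocating a slice on every iteration.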
