
Cholesky numerical stability: inverse transform #356

Draft · penelopeysm wants to merge 6 commits into master from py/chol-numerical
Conversation

@penelopeysm (Member) commented on Nov 30, 2024

Note: I only have a tiny bit of experience with numerical programming (on matrix exponentials), so this is definitely Not My Area of Expertise and I might be doing something horribly wrong.

This PR attempts to improve the numerical stability of the inverse transform for Cholesky matrices. For the forward transform, see this PR: #357


Description

This PR replaces `log1p(-z^2) / 2` (where `z = tanh(y[idx])`) with `IrrationalConstants.logtwo + y[idx] - LogExpFunctions.log1pexp(2 * y[idx])` (the same mathematical expression, rewritten for numerical stability) in `_inv_link_chol_lkj`:

$$\begin{align} z &= \frac{e^y - e^{-y}}{e^y + e^{-y}} \\ \frac{1}{2}\log(1 - z^2) &= \frac{1}{2}\log\left[1 - \left(\frac{e^y - e^{-y}}{e^y + e^{-y}}\right)^2\right] \\ &= \frac{1}{2}\log\left[1 - \frac{e^{2y} - 2 + e^{-2y}}{e^{2y} + 2 + e^{-2y}}\right] \\ &= \frac{1}{2}\log\left[\frac{4}{e^{2y} + 2 + e^{-2y}}\right] \\ &= \frac{1}{2}\log\left[\left(\frac{2}{e^{y} + e^{-y}}\right)^2\right] \\ &= \log\left[\frac{2}{e^{y} + e^{-y}}\right] \quad\quad\quad\text{(note 1)} \\ &= \log\left[\frac{2e^{y}}{e^{2y} + 1}\right] \\ &= \log 2 + y - \log\left(e^{2y} + 1\right) \end{align}$$

(Note 1: I tried implementing this directly as `log(2 / (exp(y[idx]) + exp(-y[idx])))`, but it performed worse.)

Note 2: This has now been replaced with a direct call to `LogExpFunctions.logcosh`, which performs the same calculation (see https://github.com/JuliaStats/LogExpFunctions.jl/blob/289114f535827c612ce10c01b8dec9d3a55e4d15/src/basicfuns.jl#L132-L135). Consequently the numerical stability and performance are the same as reported below.
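As a quick sanity check of the identity (a throwaway REPL snippet, not part of the PR; `y = 20.0` is just an arbitrarily large input):

```julia
using IrrationalConstants: logtwo
using LogExpFunctions: log1pexp, logcosh

y = 20.0
log1p(-tanh(y)^2) / 2      # -Inf: tanh(20.0) rounds to exactly 1.0 in Float64
logtwo + y - log1pexp(2y)  # -19.306852819440055: the rewritten form stays finite
-logcosh(y)                # -19.306852819440055: same value via logcosh
```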

Accuracy 1

First, to make sure there aren't any regressions, we'll:

  1. Sample from an LKJCholesky distribution
  2. Transform it with the existing bijector
  3. Un-transform it with both the old and the new implementation
  4. Calculate the max absolute error introduced by the roundtrip transformation
  5. Plot the error with the new implementation vs the error with the old implementation
Code to generate plot:

```julia
using Bijectors
using LinearAlgebra
using Distributions
using Random
using Plots
using IrrationalConstants
using LogExpFunctions

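# Invert d = K(K-1)/2: recover the matrix dimension K from the number of strictly-upper-triangular entries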
_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2

function _inv_link_chol_lkj_old(y::AbstractVector)
    LinearAlgebra.require_one_based_indexing(y)
    K = _triu1_dim_from_length(length(y))
    W = similar(y, K, K)
    T = float(eltype(W))
    logJ = zero(T)
    idx = 1
    @inbounds for j in 1:K
        log_remainder = zero(T)  # log of proportion of unit vector remaining
        for i in 1:(j - 1)
            z = tanh(y[idx])
            W[i, j] = z * exp(log_remainder)
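            # For large |y[idx]|, z rounds to ±1 and log1p(-z^2) loses precision or underflows to -Inf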
            log_remainder += log1p(-z^2) / 2
            logJ += log_remainder
            idx += 1
        end
        logJ += log_remainder
        W[j, j] = exp(log_remainder)
        for i in (j + 1):K
            W[i, j] = 0
        end
    end
    return W, logJ
end

function _inv_link_chol_lkj_new(y::AbstractVector)
    LinearAlgebra.require_one_based_indexing(y)
    K = _triu1_dim_from_length(length(y))
    W = similar(y, K, K)
    T = float(eltype(W))
    logJ = zero(T)
    idx = 1
    @inbounds for j in 1:K
        log_remainder = zero(T)  # log of proportion of unit vector remaining
        for i in 1:(j - 1)
            z = tanh(y[idx])
            W[i, j] = z * exp(log_remainder)
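            # log1p(-z^2) / 2 == -logcosh(y[idx]), computed without ever forming 1 - z^2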
            log_remainder -= LogExpFunctions.logcosh(y[idx])
            logJ += log_remainder
            idx += 1
        end
        logJ += log_remainder
        W[j, j] = exp(log_remainder)
        for i in (j + 1):K
            W[i, j] = 0
        end
    end
    return W, logJ
end

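# Scatter each sample's log10 max abs error, old (x) vs new (y); points below the y = x line favour the new method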
function plot_maes(samples)
    log_mae_old = log10.([sample[1] for sample in samples])
    log_mae_new = log10.([sample[2] for sample in samples])
    scatter(log_mae_old, log_mae_new, label="")
    lim_min = floor(min(minimum(log_mae_old), minimum(log_mae_new)))
    lim_max = ceil(max(maximum(log_mae_old), maximum(log_mae_new)))
    plot!(lim_min:lim_max, lim_min:lim_max, label="y=x", color=:black)
    xlabel!("log10(maximum abs error old)")
    ylabel!("log10(maximum abs error new)")
end

function test_inverse_bijector(f_old, f_new)
    dist = LKJCholesky(5, 1.0, 'U')
    b = bijector(dist)
    Random.seed!(468)
    samples = map(1:500) do _
        x = rand(dist)
        y = b(x)
        x_true = Matrix{Float64}(x.U) # Convert to full matrix
        x_old = f_old(y)[1]
        x_new = f_new(y)[1]
        # Return the maximum absolute error between the original sample
        # and sample after roundtrip transformation
        (maximum(abs.(x_true - x_old)), maximum(abs.(x_true - x_new)))
    end
    return samples
end
plot_maes(test_inverse_bijector(_inv_link_chol_lkj_old, _inv_link_chol_lkj_new))
savefig("bijector_typical.png")

bijector_typical

There isn't much between the two implementations: sometimes the old one is better, sometimes the new one. In any case the differences are very small, so the new implementation roughly breaks even here, although the old implementation is arguably very slightly better.

Accuracy 2

However, when sampling in the unconstrained space, there's no guarantee that the resulting sample will resemble anything like the samples obtained via a forward transformation. This leads to issues like #279.

To test out the numerical stability of invlinking random transformed samples, we can:

  1. Generate a random sample in the transformed (unconstrained) space.
  2. Invlink it with the old method, but using arbitrary-precision floats. This is our ground truth.
  3. Invlink it with the old method, with Float64 precision.
  4. Invlink it with the new method, with Float64 precision.
  5. Compare the errors as above.
Code to generate plot:

```julia
function test_inverse_bijector_unconstrained(f_old, f_new)
    dist = LKJCholesky(5, 1.0, 'U')
    Random.seed!(468)
    samples = map(1:500) do _
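        # Each unconstrained coordinate is drawn uniformly from [0, 10), far beyond typical linked values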
        y = rand(dist.d * (dist.d - 1) ÷ 2) * 10
        x_true = f_old(Vector{BigFloat}(y))[1]
        x_old = f_old(y)[1]
        x_new = f_new(y)[1]
        # Return the maximum absolute error between the original sample
        # and sample after roundtrip transformation
        (maximum(abs.(x_true - x_old)), maximum(abs.(x_true - x_new)))
    end
    return samples
end
plot_maes(test_inverse_bijector_unconstrained(_inv_link_chol_lkj_old, _inv_link_chol_lkj_new))
savefig("bijector_unconstrained.png")

bijector_unconstrained

As can be seen, the new method leads to much smaller errors (consistently around the magnitude of eps() ~ 1e-16), whereas the old method often has errors several orders of magnitude larger. This is expected: for large |y|, tanh(y) rounds to ±1, so log1p(-z^2) suffers catastrophic cancellation or returns -Inf outright, whereas the rewritten form never forms 1 - z^2 at all.


Performance

```julia
julia> using Chairmarks

julia> @be (rand(10) * 10) _inv_link_chol_lkj_old
Benchmark: 2592 samples with 99 evaluations
 min    245.788 ns (2 allocs: 272 bytes)
 median 269.788 ns (2 allocs: 272 bytes)
 mean   333.962 ns (2 allocs: 272 bytes, 0.08% gc time)
 max    100.998 μs (2 allocs: 272 bytes, 99.15% gc time)

julia> @be (rand(10) * 10) _inv_link_chol_lkj_new
Benchmark: 2722 samples with 90 evaluations
 min    300.000 ns (2 allocs: 272 bytes)
 median 310.656 ns (2 allocs: 272 bytes)
 mean   375.287 ns (2 allocs: 272 bytes, 0.07% gc time)
 max    107.588 μs (2 allocs: 272 bytes, 99.27% gc time)
```
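For reference, that makes the new implementation around 15% slower at the median (~270 ns vs ~311 ns), which seems a modest price for the stability gain.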

What next

Note that this PR doesn't actually fully solve #279. That issue arises not because of the inverse transformation, but because of the forward transformation (in the call to logabsdetjac); this is a result of further numerical instabilities in other functions, specifically the linking one. I've been having a go at those, but haven't had much success (yet...?). The snippet below reproduces the failure:

```julia
using Bijectors
using Distributions

θ_unconstrained = [
	-1.9887091960524537,
	-13.499454444466279,
	-0.39328331954134665,
	-4.426097270849902,
	13.101175413857023,
	7.66647404712346,
	9.249285786544894,
	4.714877413573335,
	6.233118490809442,
	22.28264809311481
]
n = 5
d = LKJCholesky(n, 10)
b = Bijectors.bijector(d)
b_inv = inverse(b)

θ = b_inv(θ_unconstrained)
Bijectors.logabsdetjac(b, θ)

# ERROR: DomainError with 1.0085229361957693:
# atanh(x) is only defined for |x| ≤ 1.
```

However, this PR does allow the log-determinant of the Jacobian of the inverse to be calculated. On master this would return -Inf:

```julia
julia> Bijectors.logabsdetjac(inverse(b), θ_unconstrained)
-225.9679826839954
```

penelopeysm marked this pull request as draft on Nov 30, 2024 17:40.
penelopeysm force-pushed the py/chol-numerical branch 2 times, most recently from 164a33c to 9d11ba4 on Nov 30, 2024 18:00.
@devmotion (Member) left a comment:

AFAICT it's even simpler since $$\log(1 - \tanh^2(x))/2 = \log(\mathrm{sech}^2(x)) / 2 = \log(\mathrm{sech}(x)) = \log(1 / \cosh(x)) = -\log(\cosh(x))$$, and LogExpFunctions.logcosh is supposed to provide a numerically stable and efficient implementation of $$\log(\cosh(\cdot))$$ (if there are problems they should be considered bugs).

(Three review comments on src/bijectors/corr.jl, now outdated and resolved.)
@penelopeysm (Member, Author) commented:

👀 I'll give that a spin

@penelopeysm (Member, Author) commented:

Looking under the hood, logcosh is implemented the same way as above, but the single function call is great 👍

(Review thread on the following diff in src/bijectors/corr.jl, now outdated and resolved:)
```diff
@@ -495,8 +490,7 @@ function _logabsdetjac_inv_chol(y::AbstractVector)
     @inbounds for j in 2:K
         tmp = zero(result)
         for _ in 1:(j - 1)
-            z = tanh(y[idx])
-            logz = 2 * log(2 / (exp(y[idx]) + exp(-y[idx])))
+            logz = -2 * LogExpFunctions.logcosh(y[idx])
```
@devmotion (Member) commented:

The name `logz` is a bit meaningless now I guess 🙂

penelopeysm changed the title from "Attempt to improve Cholesky numerical stability" to "Cholesky numerical stability: inverse transform" on Dec 1, 2024.