Cholesky numerical stability: Forward transform #357
+4 −6
This is a companion PR to #356. It attempts to solve the following issue, first reported in #279:
Introduction
The forward transform acts on an upper triangular matrix, W, which is supposed to have unit vectors for each column, i.e. sum(W[:, j] .^ 2) should be 1 for each j:
In the forward transform code, remainder_sq is initialised at one and then the squares of each element going down column j are successively subtracted, so remainder_sq is really a sum of squares of elements not yet seen.
Bijectors.jl/src/bijectors/corr.jl
Lines 321 to 331 in f52a9c5
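To make the subtraction scheme concrete, here is a minimal single-column sketch. It is not the code from corr.jl: the function name `forward_column_subtractive` and the convention of passing one column as a vector `w = W[1:j, j]` are assumptions made for illustration.

```julia
# Hypothetical single-column helper; not the Bijectors.jl source.
# `w` holds the entries W[1:j, j] of one column, diagonal entry last.
function forward_column_subtractive(w::AbstractVector{<:Real})
    T = float(eltype(w))
    j = length(w)
    y = zeros(T, j - 1)
    remainder_sq = one(T)               # starts at 1, the assumed squared norm of the column
    for i in 1:(j - 1)
        z = w[i] / sqrt(remainder_sq)   # scale by the norm of the entries not yet seen
        y[i] = atanh(z)                 # throws DomainError for real z with |z| > 1
        remainder_sq -= w[i]^2          # in exact arithmetic this equals sum(w[i+1:end] .^ 2)
    end
    return y
end

y = forward_column_subtractive([0.6, 0.0, 0.8])
```

As an exaggerated stand-in for the rounding problem, a column that is not unit norm, such as `[0.6, 0.9, 0.1]`, makes this sketch throw a DomainError at the second entry, because `0.9 / sqrt(1 - 0.36)` is about 1.125.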
Now, in principle, because z^2 = W[i, j]^2 / (sum of W[i:end, j]^2), there is no way that z^2 can be larger than 1.
However, because of floating point imprecision, sometimes this isn't true. This is especially likely to happen if the last element W[j-1, j] is very small. This doesn't tend to happen when W is sampled from LKJCholesky, but it can happen when W is obtained through inverse transformation of some random unconstrained vector, as described in e.g. #279.
A proposed fix, instead of subtracting successive squares from 1, could just declare remainder_sq to be that sum:
(In practice, I short-circuited the sqrt by using norm(v) instead of sum(v .^ 2).)
Now, because W[i, j] ^ 2 is part of that sum, the magnitude of z can no longer exceed 1, and atanh doesn't throw a DomainError.
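The proposed fix described above can be sketched for a single column as follows. This is an illustration under assumptions, not the diff in this PR: the name `forward_column_norm` is hypothetical, and the sketch applies the norm to every entry, including the first.

```julia
using LinearAlgebra: norm

# Hypothetical single-column helper; not the Bijectors.jl source.
# `w` holds the entries W[1:j, j] of one column, diagonal entry last.
function forward_column_norm(w::AbstractVector{<:Real})
    T = float(eltype(w))
    j = length(w)
    y = zeros(T, j - 1)
    for i in 1:(j - 1)
        # Norm of the entries not yet processed, including w[i] itself,
        # recomputed from scratch instead of tracked by subtraction.
        z = w[i] / norm(view(w, i:j))
        y[i] = atanh(z)   # z == ±1 gives ±Inf rather than a DomainError
    end
    return y
end

y = forward_column_norm([0.5, 0.5, 0.5, 0.5])
```

Recomputing the norm at every step costs work proportional to the remaining column length, which is consistent with the slowdown reported under the performance concern.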
Impacts of this change
First, let's check round-trip transformation on typical samples from LKJCholesky. The numerical accuracy here is actually marginally better than with the existing implementation:
On top of that, it fixes the DomainErrors which occur with random unconstrained inputs:
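The original output of these checks is not reproduced here. As a rough sketch of how such a check could be set up on one column, the snippet below inverse-transforms an unconstrained vector and then applies a norm-based forward step. Both helper names are hypothetical, the inverse follows the usual tanh-based construction rather than the Bijectors.jl code, and the clamp at zero is an assumption of this sketch.

```julia
using LinearAlgebra: norm

# Hypothetical single-column helpers for illustration; not the Bijectors.jl API.

# Inverse: map unconstrained y (length j - 1) to a column of length j.
function inverse_column(y::AbstractVector{<:Real})
    T = float(eltype(y))
    j = length(y) + 1
    w = zeros(T, j)
    remainder_sq = one(T)
    for i in 1:(j - 1)
        w[i] = tanh(y[i]) * sqrt(remainder_sq)
        # Clamp at zero (an assumption of this sketch) so sqrt never sees a negative value.
        remainder_sq = max(remainder_sq - w[i]^2, zero(T))
    end
    w[j] = sqrt(remainder_sq)
    return w
end

# Forward: recompute the remaining norm at every step instead of subtracting.
function forward_column_norm(w::AbstractVector{<:Real})
    j = length(w)
    y = zeros(float(eltype(w)), j - 1)
    for i in 1:(j - 1)
        y[i] = atanh(w[i] / norm(view(w, i:j)))
    end
    return y
end

y = [0.3, -1.2, 0.5]
w = inverse_column(y)
y_roundtrip = forward_column_norm(w)
max_abs_error = maximum(abs.(y_roundtrip .- y))
```

For large-magnitude inputs, `tanh` saturates and the remaining entries of the column collapse towards zero, so the round trip can lose accuracy there even when no error is thrown.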
Remaining concerns 1: performance
It's bad: 387 ns for the new implementation versus 111 ns for the original, roughly 3.5 times slower.
Remaining concerns 2: accuracy on pathological samples
It's not great, but considering that the existing implementation errors, this is still a net win.
Hybrid implementation?
One option to improve performance could be to use the default implementation, unless z > sqrt(remainder_sq), in which case we would recalculate remainder_sq by summation rather than subtraction. This introduces a much smaller overhead:
(from above, the original implementation was 111 ns, the pure new implementation with recalculation on every step is 387 ns)
Unfortunately, this hybrid implementation is numerically rather unstable, and using it could therefore introduce silent errors:
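For reference, the hybrid idea can be sketched on a single column as below. The function name is hypothetical, and the fallback trigger (`w[i]^2` exceeding the tracked `remainder_sq`) is one reading of the condition described above, not the code that was benchmarked. Given the instability noted above, treat this as an illustration of the control flow rather than a recommended implementation.

```julia
# Hypothetical single-column helper; not the Bijectors.jl source.
function forward_column_hybrid(w::AbstractVector{<:Real})
    T = float(eltype(w))
    j = length(w)
    y = zeros(T, j - 1)
    remainder_sq = one(T)
    for i in 1:(j - 1)
        # Fast path: trust the subtracted remainder. Fall back to an explicit
        # sum of squares only when the tracked value would give |z| > 1
        # (the negated comparison also catches a NaN remainder).
        if !(w[i]^2 <= remainder_sq)
            remainder_sq = sum(abs2, view(w, i:j))
        end
        z = w[i] / sqrt(remainder_sq)
        y[i] = atanh(z)
        remainder_sq -= w[i]^2
    end
    return y
end

y_ok = forward_column_hybrid([0.6, 0.0, 0.8])        # fast path only
y_fallback = forward_column_hybrid([0.6, 0.9, 0.1])  # second entry triggers the fallback
```

The fallback only fires once the tracked remainder is already badly wrong, so smaller accumulated errors on the fast path pass through unnoticed, which fits the silent errors described above.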