From 74af37b9bb539d620b79aecb950a47546a91fbe2 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 14 Apr 2022 20:08:25 -0600 Subject: [PATCH 1/6] Switch to more numerically stable calculation of root of cubic equation --- scico/loss.py | 64 +++++++++++++++++++++++++-------------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index 74f62d482..74e3df3fb 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -67,7 +67,6 @@ def __init__( :meth:`__call__` and :meth:`prox` (where appropriate) must be defined in a derived class. scale: Scaling parameter. Default: 1.0. - """ self.y = ensure_on_device(y) if A is None: @@ -371,46 +370,47 @@ def prox( return x -def _cbrt(x): - """Compute the cube root of the argument. +def _dep_cubic_root(p: Union[JaxArray, BlockArray], q: Union[JaxArray, BlockArray]): + r"""Compute a real root of a depressed cubic equation. - The two standard options for computing the cube root of an array are - :func:`numpy.cbrt`, or raising to the power of (1/3), i.e. `x ** (1/3)`. - The former cannot be used for complex values, and the latter returns - a complex root of a negative real value. This functions can be used - for both real and complex values, and returns the real root of - negative real values. + A depressed cubic equation is one that can be written in the form - Args: - x: Input array + .. math:: + x^3 + px + q \;. - Returns: - Array of cube roots of input `x`. - """ - s = snp.where(snp.abs(snp.angle(x)) <= 3 * snp.pi / 4, 1, -1) - return s * (s * x) ** (1 / 3) + When :math:`\Delta = (q/2)^2 + (p/3)^3 > 0` this equation has one + real root and two complex (conjugate) roots. When :math:`\Delta = 0`, + all three roots are real, with at least two being equal, and when + :math:`\Delta < 0`, all roots are real and unequal. According to + Vieta's formulas, the roots :math:`x_0, x_1`, and :math:`x_2` of this + equation satisfy + .. 
math:: + x_0 + x_1 + x_2 &= 0 \\ + x_0 x_1 + x_0 x_2 + x_1 x_2 &= p \\ + x_0 x_1 x_2 &= -q \;. -def _dep_cubic_root(p, q): - r"""Compute the positive real root of a depressed cubic equation. + Therefore, when :math:`q` is negative, the equation has a single real + positive root since at least one must be negative for their sum to + be zero, and their product could not be positive if only one were + zero. This function always returns a real root; when :math:`q` is + negative, it returns the single positive root. - A depressed cubic equation is one that can be written in the form + The solution is computed using + `Vieta's substitution `__. - .. math:: - x^3 + px + q \;. + Args: + p: Array of :math:`p` values. + q: Array of :math:`q` values. - This function finds the positive real root of such an equation via - `Cardano's method `__, - for `p` and `q` such that there is a single positive real root - (see Sec. 3.C of :cite:`soulez-2016-proximity`). + Returns: + Array of real roots of the cubic equation. 
""" - q2 = q / 2 - Δ = q2**2 + (p / 3) ** 3 - Δrt = snp.sqrt(Δ + 0j) - u3, v3 = -q2 + Δrt, -q2 - Δrt - u, v = _cbrt(u3), _cbrt(v3) - r = (u + v).real - assert snp.allclose(snp.abs(r**3 + p * r + q), 0, atol=1e-4) + Δ = (q**2) / 4.0 + (p**3) / 27.0 + w3 = -q / 2.0 + snp.sqrt(Δ + 0j) + w = w3 ** (1 / 3) + r = (w - no_nan_divide(p, 3 * w)).real + assert snp.allclose(snp.abs(r**3 + p * r + q), 0, atol=1e-5) return r From f1a949686ed1e2cd7d0139ed51f1bd830f0bcd91 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 15 Apr 2022 08:48:19 -0600 Subject: [PATCH 2/6] Improve docs --- scico/loss.py | 52 +++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index 74e3df3fb..8e6ee1fc6 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -376,14 +376,20 @@ def _dep_cubic_root(p: Union[JaxArray, BlockArray], q: Union[JaxArray, BlockArra A depressed cubic equation is one that can be written in the form .. math:: - x^3 + px + q \;. + x^3 + px + q = 0 \;. - When :math:`\Delta = (q/2)^2 + (p/3)^3 > 0` this equation has one - real root and two complex (conjugate) roots. When :math:`\Delta = 0`, - all three roots are real, with at least two being equal, and when - :math:`\Delta < 0`, all roots are real and unequal. According to - Vieta's formulas, the roots :math:`x_0, x_1`, and :math:`x_2` of this - equation satisfy + The determinant is + + .. math:: + \Delta = (q/2)^2 + (p/3)^3 \;. + + When :math:`\Delta > 0` this equation has one real root and two + complex (conjugate) roots, when :math:`\Delta = 0`, all three roots + are real, with at least two being equal, and when :math:`\Delta < 0`, + all roots are real and unequal. + + According to Vieta's formulas, the roots :math:`x_0, x_1`, and + :math:`x_2` of this equation satisfy .. math:: x_0 + x_1 + x_2 &= 0 \\ @@ -391,13 +397,35 @@ def _dep_cubic_root(p: Union[JaxArray, BlockArray], q: Union[JaxArray, BlockArra x_0 x_1 x_2 &= -q \;. 
Therefore, when :math:`q` is negative, the equation has a single real - positive root since at least one must be negative for their sum to - be zero, and their product could not be positive if only one were - zero. This function always returns a real root; when :math:`q` is - negative, it returns the single positive root. + positive root since at least one root must be negative for their sum + to be zero, and their product could not be positive if only one root + were zero. This function always returns a real root; when :math:`q` + is negative, it returns the single positive root. The solution is computed using - `Vieta's substitution `__. + `Vieta's substitution `__, + + .. math:: + x = w - \frac{p}{3w} \;, + + which reduces the depressed cubic equation to + + .. math:: + w^3 - \frac{p^3}{27w^3} + q = 0\;, + + which can be expressed as a quadratic equation in :math:`w^3` by + multiplication by :math:`w^3`, leading to + + .. math:: + w^3 = -\frac{q}{2} \pm \sqrt{\frac{q^2}{4} + \frac{p^3}{27}} \;. + + An alternative derivation leads to the equation + + .. math:: + x = \sqrt[3]{-q/2 + \sqrt{\Delta}} + \sqrt[3]{-q/2 - \sqrt{\Delta}} + + for the real root, but this is not suitable for use here due to severe + numerical errors in single precision arithmetic. Args: p: Array of :math:`p` values. 
From 853ad4fd94501927b0230130dcb6beb8af963a1c Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 15 Apr 2022 15:43:16 -0600 Subject: [PATCH 3/6] Improve algorithm, docs, and tests --- scico/loss.py | 80 +++++++++++++++++++++++++++--- scico/test/functional/test_loss.py | 11 +++- 2 files changed, 83 insertions(+), 8 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index 8e6ee1fc6..40ce23c49 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -8,7 +8,7 @@ """Loss function classes.""" - +import warnings from copy import copy from functools import wraps from typing import Callable, Optional, Union @@ -370,7 +370,56 @@ def prox( return x -def _dep_cubic_root(p: Union[JaxArray, BlockArray], q: Union[JaxArray, BlockArray]): +def _cbrt(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: + """Compute the cube root of the argument. + + The two standard options for computing the cube root of an array are + :func:`numpy.cbrt`, or raising to the power of (1/3), i.e. `x ** (1/3)`. + The former cannot be used for complex values, and the latter returns + a complex root of a negative real value. This function can be used + for both real and complex values, and returns the real root of + negative real values. + + Args: + x: Input array. + + Returns: + Array of cube roots of input `x`. + """ + s = snp.where(snp.abs(snp.angle(x)) <= 2 * snp.pi / 3, 1, -1) + return s * (s * x) ** (1 / 3) + + +def _check_root( + x: Union[JaxArray, BlockArray], + p: Union[JaxArray, BlockArray], + q: Union[JaxArray, BlockArray], + tol: float = 1e-4, +): + """Check the precision of a cubic equation solution. + + Check the precision of an array of depressed cubic equation solutions, + issuing a warning if any of the errors exceed a specified tolerance. + + Args: + x: Array of roots of a depressed cubic equation. + p: Array of linear parameters of a depressed cubic equation. + q: Array of constant parameters of a depressed cubic equation. 
+ tol: Expected tolerance for solution precision. + """ + err = snp.abs(x**3 + p * x + q) + if not snp.allclose(err, 0, atol=tol): + idx = snp.argmax(err) + msg = ( + "Low precision in root calculation. Worst error is " + f"{err.ravel()[idx]:.3e} for p={p.ravel()[idx]} and q={q.ravel()[idx]}" + ) + warnings.warn(msg) + + +def _dep_cubic_root( + p: Union[JaxArray, BlockArray], q: Union[JaxArray, BlockArray] +) -> Union[JaxArray, BlockArray]: r"""Compute a real root of a depressed cubic equation. A depressed cubic equation is one that can be written in the form @@ -419,13 +468,30 @@ def _dep_cubic_root(p: Union[JaxArray, BlockArray], q: Union[JaxArray, BlockArra .. math:: w^3 = -\frac{q}{2} \pm \sqrt{\frac{q^2}{4} + \frac{p^3}{27}} \;. + Note that the multiplication by :math:`w^3` introduces a spurious + solution at zero in the case :math:`p = 0`, which must be handled + separately as + + .. math:: + w^3 = -q \;. + + Despite taking this into account, very poor numerical precision is + obtained when :math:`p` is small but non-zero since, in this case + + .. math:: + \sqrt{\Delta} = \sqrt{(q/2)^2 + (p/3)^3} \approx q/2 \;, + + so that incorrect solutions :math:`w^3 = 0` or :math:`w^3 = -q` + are obtained, depending on the choice of sign in the equation for + :math:`w^3`. + An alternative derivation leads to the equation .. math:: x = \sqrt[3]{-q/2 + \sqrt{\Delta}} + \sqrt[3]{-q/2 - \sqrt{\Delta}} - for the real root, but this is not suitable for use here due to severe - numerical errors in single precision arithmetic. + for the real root, but this is also prone to severe numerical errors + in single precision arithmetic. Args: p: Array of :math:`p` values. @@ -435,10 +501,10 @@ def _dep_cubic_root(p: Union[JaxArray, BlockArray], q: Union[JaxArray, BlockArra Array of real roots of the cubic equation. 
""" Δ = (q**2) / 4.0 + (p**3) / 27.0 - w3 = -q / 2.0 + snp.sqrt(Δ + 0j) - w = w3 ** (1 / 3) + w3 = snp.where(snp.abs(p) <= 1e-7, -q, -q / 2.0 + snp.sqrt(Δ + 0j)) + w = _cbrt(w3) r = (w - no_nan_divide(p, 3 * w)).real - assert snp.allclose(snp.abs(r**3 + p * r + q), 0, atol=1e-5) + _check_root(r, p, q) return r diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py index ae1ea901c..7607a1b74 100644 --- a/scico/test/functional/test_loss.py +++ b/scico/test/functional/test_loss.py @@ -13,7 +13,7 @@ import scico.numpy as snp from scico import functional, linop, loss from scico.array import complex_dtype -from scico.random import randn +from scico.random import randn, uniform class TestLoss: @@ -208,3 +208,12 @@ def test_prox(self, loss_tuple): pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 0.0) # complex zero v pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 1.0) # complex zero v + + +def test_cubic_root(): + N = 10000 + p, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0) + q, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0) + r = loss._dep_cubic_root(p, q) + err = snp.abs(r**3 + p * r + q) + assert err.max() < 1e-4 From 15d33a9ff94d3747b7161ff34f502731408863d1 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 19 Apr 2022 15:20:50 -0600 Subject: [PATCH 4/6] Fix use of scico.random.uniform --- scico/test/functional/test_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py index 7607a1b74..981c788a8 100644 --- a/scico/test/functional/test_loss.py +++ b/scico/test/functional/test_loss.py @@ -212,8 +212,8 @@ def test_prox(self, loss_tuple): def test_cubic_root(): N = 10000 - p, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0) - q, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0) + p, key = uniform(shape=(N,), dtype=snp.float32, 
minval=-10.0, maxval=10.0, seed=1234) + q, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0, key=key) r = loss._dep_cubic_root(p, q) err = snp.abs(r**3 + p * r + q) assert err.max() < 1e-4 From b1f2f0714b7c5c938d38141508733e126c5bcfb7 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 19 Apr 2022 15:47:41 -0600 Subject: [PATCH 5/6] Improve docstring --- scico/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index 40ce23c49..2c4e98c40 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -475,8 +475,8 @@ def _dep_cubic_root( .. math:: w^3 = -q \;. - Despite taking this into account, very poor numerical precision is - obtained when :math:`p` is small but non-zero since, in this case + Despite taking this into account, very poor numerical precision can + be obtained when :math:`p` is small but non-zero since, in this case .. math:: \sqrt{\Delta} = \sqrt{(q/2)^2 + (p/3)^3} \approx q/2 \;, From ab7e263f1010be0823cabb4ccc1f7167f6b5bb58 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 19 Apr 2022 15:48:00 -0600 Subject: [PATCH 6/6] Improve test --- scico/test/functional/test_loss.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py index 981c788a8..c606fb285 100644 --- a/scico/test/functional/test_loss.py +++ b/scico/test/functional/test_loss.py @@ -214,6 +214,13 @@ def test_cubic_root(): N = 10000 p, key = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0, seed=1234) q, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0, key=key) + # Avoid cases of very poor numerical precision + p = p.at[snp.logical_and(snp.abs(p) < 2, q > 5e-2 * snp.abs(p))].set(1e1) r = loss._dep_cubic_root(p, q) err = snp.abs(r**3 + p * r + q) - assert err.max() < 1e-4 + assert err.max() < 2e-4 + # Test + p = snp.array(1e-4, dtype=snp.float32) + q = snp.array(1e1, dtype=snp.float32) + 
with pytest.warns(UserWarning): + r = loss._dep_cubic_root(p, q)