diff --git a/scico/loss.py b/scico/loss.py index 74f62d482..2c4e98c40 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -8,7 +8,7 @@ """Loss function classes.""" - +import warnings from copy import copy from functools import wraps from typing import Callable, Optional, Union @@ -67,7 +67,6 @@ def __init__( :meth:`__call__` and :meth:`prox` (where appropriate) must be defined in a derived class. scale: Scaling parameter. Default: 1.0. - """ self.y = ensure_on_device(y) if A is None: @@ -371,7 +370,7 @@ def prox( return x -def _cbrt(x): +def _cbrt(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: """Compute the cube root of the argument. The two standard options for computing the cube root of an array are @@ -382,35 +381,130 @@ def _cbrt(x): negative real values. Args: - x: Input array + x: Input array. Returns: Array of cube roots of input `x`. """ - s = snp.where(snp.abs(snp.angle(x)) <= 3 * snp.pi / 4, 1, -1) + s = snp.where(snp.abs(snp.angle(x)) <= 2 * snp.pi / 3, 1, -1) return s * (s * x) ** (1 / 3) -def _dep_cubic_root(p, q): - r"""Compute the positive real root of a depressed cubic equation. +def _check_root( + x: Union[JaxArray, BlockArray], + p: Union[JaxArray, BlockArray], + q: Union[JaxArray, BlockArray], + tol: float = 1e-4, +): + """Check the precision of a cubic equation solution. + + Check the precision of an array of depressed cubic equation solutions, + issuing a warning if any of the errors exceed a specified tolerance. + + Args: + x: Array of roots of a depressed cubic equation. + p: Array of linear parameters of a depressed cubic equation. + q: Array of constant parameters of a depressed cubic equation. + tol: Expected tolerance for solution precision. + """ + err = snp.abs(x**3 + p * x + q) + if not snp.allclose(err, 0, atol=tol): + idx = snp.argmax(err) + msg = ( + "Low precision in root calculation. 
Worst error is " + f"{err.ravel()[idx]:.3e} for p={p.ravel()[idx]} and q={q.ravel()[idx]}" + ) + warnings.warn(msg) + + +def _dep_cubic_root( + p: Union[JaxArray, BlockArray], q: Union[JaxArray, BlockArray] +) -> Union[JaxArray, BlockArray]: + r"""Compute a real root of a depressed cubic equation. A depressed cubic equation is one that can be written in the form .. math:: - x^3 + px + q \;. + x^3 + px + q = 0 \;. + + The discriminant is + + .. math:: + \Delta = (q/2)^2 + (p/3)^3 \;. + + When :math:`\Delta > 0` this equation has one real root and two + complex (conjugate) roots, when :math:`\Delta = 0`, all three roots + are real, with at least two being equal, and when :math:`\Delta < 0`, + all roots are real and unequal. + + According to Vieta's formulas, the roots :math:`x_0, x_1`, and + :math:`x_2` of this equation satisfy + + .. math:: + x_0 + x_1 + x_2 &= 0 \\ + x_0 x_1 + x_0 x_2 + x_1 x_2 &= p \\ + x_0 x_1 x_2 &= -q \;. + + Therefore, when :math:`q` is negative, the equation has a single real + positive root since at least one root must be negative for their sum + to be zero, and their product could not be positive if only one root + were negative. This function always returns a real root; when :math:`q` + is negative, it returns the single positive root. + + The solution is computed using + `Vieta's substitution <https://mathworld.wolfram.com/VietasSubstitution.html>`__, + + .. math:: + x = w - \frac{p}{3w} \;, - This function finds the positive real root of such an equation via - `Cardano's method `__, - for `p` and `q` such that there is a single positive real root - (see Sec. 3.C of :cite:`soulez-2016-proximity`). + which reduces the depressed cubic equation to + + .. math:: + w^3 - \frac{p^3}{27w^3} + q = 0\;, + + which can be expressed as a quadratic equation in :math:`w^3` by + multiplication by :math:`w^3`, leading to + + .. math:: + w^3 = -\frac{q}{2} \pm \sqrt{\frac{q^2}{4} + \frac{p^3}{27}} \;. 
+ + Note that the multiplication by :math:`w^3` introduces a spurious + solution at zero in the case :math:`p = 0`, which must be handled + separately as + + .. math:: + w^3 = -q \;. + + Despite taking this into account, very poor numerical precision can + be obtained when :math:`p` is small but non-zero since, in this case + + .. math:: + \sqrt{\Delta} = \sqrt{(q/2)^2 + (p/3)^3} \approx |q|/2 \;, + + so that incorrect solutions :math:`w^3 = 0` or :math:`w^3 = -q` + are obtained, depending on the choice of sign in the equation for + :math:`w^3`. + + An alternative derivation leads to the equation + + .. math:: + x = \sqrt[3]{-q/2 + \sqrt{\Delta}} + \sqrt[3]{-q/2 - \sqrt{\Delta}} + + for the real root, but this is also prone to severe numerical errors + in single precision arithmetic. + + Args: + p: Array of :math:`p` values. + q: Array of :math:`q` values. + + Returns: + Array of real roots of the cubic equation. """ - q2 = q / 2 - Δ = q2**2 + (p / 3) ** 3 - Δrt = snp.sqrt(Δ + 0j) - u3, v3 = -q2 + Δrt, -q2 - Δrt - u, v = _cbrt(u3), _cbrt(v3) - r = (u + v).real - assert snp.allclose(snp.abs(r**3 + p * r + q), 0, atol=1e-4) + Δ = (q**2) / 4.0 + (p**3) / 27.0 + w3 = snp.where(snp.abs(p) <= 1e-7, -q, -q / 2.0 + snp.sqrt(Δ + 0j)) + w = _cbrt(w3) + r = (w - no_nan_divide(p, 3 * w)).real + _check_root(r, p, q) return r diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py index ae1ea901c..c606fb285 100644 --- a/scico/test/functional/test_loss.py +++ b/scico/test/functional/test_loss.py @@ -13,7 +13,7 @@ import scico.numpy as snp from scico import functional, linop, loss from scico.array import complex_dtype -from scico.random import randn +from scico.random import randn, uniform class TestLoss: @@ -208,3 +208,19 @@ def test_prox(self, loss_tuple): pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 0.0) # complex zero v pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 1.0) # complex zero v + + +def test_cubic_root(): 
+ N = 10000 + p, key = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0, seed=1234) + q, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0, key=key) + # Avoid cases of very poor numerical precision + p = p.at[snp.logical_and(snp.abs(p) < 2, q > 5e-2 * snp.abs(p))].set(1e1) + r = loss._dep_cubic_root(p, q) + err = snp.abs(r**3 + p * r + q) + assert err.max() < 2e-4 + # Test + p = snp.array(1e-4, dtype=snp.float32) + q = snp.array(1e1, dtype=snp.float32) + with pytest.warns(UserWarning): + r = loss._dep_cubic_root(p, q)