Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve SquaredL2SquaredAbsLoss #278

Merged
merged 6 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 113 additions & 19 deletions scico/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

"""Loss function classes."""


import warnings
from copy import copy
from functools import wraps
from typing import Callable, Optional, Union
Expand Down Expand Up @@ -67,7 +67,6 @@ def __init__(
:meth:`__call__` and :meth:`prox` (where appropriate) must
be defined in a derived class.
scale: Scaling parameter. Default: 1.0.

"""
self.y = ensure_on_device(y)
if A is None:
Expand Down Expand Up @@ -371,7 +370,7 @@ def prox(
return x


def _cbrt(x):
def _cbrt(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
"""Compute the cube root of the argument.

The two standard options for computing the cube root of an array are
Expand All @@ -382,35 +381,130 @@ def _cbrt(x):
negative real values.

Args:
x: Input array
x: Input array.

Returns:
Array of cube roots of input `x`.
"""
s = snp.where(snp.abs(snp.angle(x)) <= 3 * snp.pi / 4, 1, -1)
s = snp.where(snp.abs(snp.angle(x)) <= 2 * snp.pi / 3, 1, -1)
return s * (s * x) ** (1 / 3)


def _dep_cubic_root(p, q):
r"""Compute the positive real root of a depressed cubic equation.
def _check_root(
    x: Union[JaxArray, BlockArray],
    p: Union[JaxArray, BlockArray],
    q: Union[JaxArray, BlockArray],
    tol: float = 1e-4,
):
    """Warn if depressed cubic equation solutions are imprecise.

    Evaluate the residual of the depressed cubic equation at each
    candidate root and issue a warning, reporting the largest residual
    and the corresponding parameters, if the residuals are not all
    within the specified tolerance of zero.

    Args:
        x: Array of roots of a depressed cubic equation.
        p: Array of linear parameters of a depressed cubic equation.
        q: Array of constant parameters of a depressed cubic equation.
        tol: Expected tolerance for solution precision.
    """
    # Absolute residual of x^3 + p x + q at each candidate root.
    err = snp.abs(x**3 + p * x + q)
    if snp.allclose(err, 0, atol=tol):
        return

    # Flat index of the worst residual, used to report the matching p and q.
    idx = snp.argmax(err)
    warnings.warn(
        "Low precision in root calculation. Worst error is "
        f"{err.ravel()[idx]:.3e} for p={p.ravel()[idx]} and q={q.ravel()[idx]}"
    )


def _dep_cubic_root(
p: Union[JaxArray, BlockArray], q: Union[JaxArray, BlockArray]
) -> Union[JaxArray, BlockArray]:
r"""Compute a real root of a depressed cubic equation.

A depressed cubic equation is one that can be written in the form

.. math::
x^3 + px + q \;.
x^3 + px + q = 0 \;.

The discriminant is

.. math::
\Delta = (q/2)^2 + (p/3)^3 \;.

When :math:`\Delta > 0` this equation has one real root and two
complex (conjugate) roots, when :math:`\Delta = 0`, all three roots
are real, with at least two being equal, and when :math:`\Delta < 0`,
all roots are real and unequal.

According to Vieta's formulas, the roots :math:`x_0, x_1`, and
:math:`x_2` of this equation satisfy

.. math::
x_0 + x_1 + x_2 &= 0 \\
x_0 x_1 + x_0 x_2 + x_1 x_2 &= p \\
x_0 x_1 x_2 &= -q \;.

Therefore, when :math:`q` is negative, the equation has a single real
positive root since at least one root must be negative for their sum
to be zero, and their product could not be positive if only one root
were negative. This function always returns a real root; when :math:`q`
is negative, it returns the single positive root.

The solution is computed using
`Vieta's substitution <https://mathworld.wolfram.com/CubicFormula.html>`__,

.. math::
x = w - \frac{p}{3w} \;,

This function finds the positive real root of such an equation via
`Cardano's method <https://en.wikipedia.org/wiki/Cubic_equation#Cardano's_formula>`__,
for `p` and `q` such that there is a single positive real root
(see Sec. 3.C of :cite:`soulez-2016-proximity`).
which reduces the depressed cubic equation to

.. math::
w^3 - \frac{p^3}{27w^3} + q = 0\;,

which can be expressed as a quadratic equation in :math:`w^3` by
multiplication by :math:`w^3`, leading to

.. math::
w^3 = -\frac{q}{2} \pm \sqrt{\frac{q^2}{4} + \frac{p^3}{27}} \;.

Note that the multiplication by :math:`w^3` introduces a spurious
solution at zero in the case :math:`p = 0`, which must be handled
separately as

.. math::
w^3 = -q \;.

Despite taking this into account, very poor numerical precision is
obtained when :math:`p` is small but non-zero since, in this case

.. math::
\sqrt{\Delta} = \sqrt{(q/2)^2 + (p/3)^3} \approx q/2 \;,

so that one of the incorrect solutions :math:`w^3 = 0` or
:math:`w^3 = -q` is obtained, depending on the choice of sign in the
equation for :math:`w^3`.

An alternative derivation leads to the equation

.. math::
x = \sqrt[3]{-q/2 + \sqrt{\Delta}} + \sqrt[3]{-q/2 - \sqrt{\Delta}}

for the real root, but this is also prone to severe numerical errors
in single precision arithmetic.

Args:
p: Array of :math:`p` values.
q: Array of :math:`q` values.

Returns:
Array of real roots of the cubic equation.
"""
q2 = q / 2
Δ = q2**2 + (p / 3) ** 3
Δrt = snp.sqrt(Δ + 0j)
u3, v3 = -q2 + Δrt, -q2 - Δrt
u, v = _cbrt(u3), _cbrt(v3)
r = (u + v).real
assert snp.allclose(snp.abs(r**3 + p * r + q), 0, atol=1e-4)
Δ = (q**2) / 4.0 + (p**3) / 27.0
w3 = snp.where(snp.abs(p) <= 1e-7, -q, -q / 2.0 + snp.sqrt(Δ + 0j))
w = _cbrt(w3)
r = (w - no_nan_divide(p, 3 * w)).real
_check_root(r, p, q)
return r


Expand Down
11 changes: 10 additions & 1 deletion scico/test/functional/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import scico.numpy as snp
from scico import functional, linop, loss
from scico.array import complex_dtype
from scico.random import randn
from scico.random import randn, uniform


class TestLoss:
Expand Down Expand Up @@ -208,3 +208,12 @@ def test_prox(self, loss_tuple):

pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 0.0) # complex zero v
pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 1.0) # complex zero v


def test_cubic_root():
N = 10000
p, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider seeding rng here so the test result is not stochastic

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

q, _ = uniform(shape=(N,), dtype=snp.float32, minval=-10.0, maxval=10.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax gotcha: I think p and q will be equal if you make them this way. Need to get the key from the first call and pass it to the second.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed they were. Fixed.

r = loss._dep_cubic_root(p, q)
err = snp.abs(r**3 + p * r + q)
assert err.max() < 1e-4