Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix L21Norm handling of BlockArray input #506

Merged
merged 3 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,34 +194,37 @@ class L21Norm(Functional):

The norm generalizes to more dimensions by first computing the
:math:`\ell_2` norm along one or more (user-specified) axes,
followed by a sum over all remaining axes.

For `BlockArray` inputs, the :math:`\ell_2` norm follows the
reduction rules described in :class:`BlockArray`.
followed by a sum over all remaining axes. :class:`.BlockArray` inputs
require parameter `l2_axis` to be ``None``, in which case the
:math:`\ell_2` norm is computed over each block.

A typical use case is computing the isotropic total variation norm.
"""

has_eval = True
has_prox = True

def __init__(self, l2_axis: Union[int, Tuple] = 0):
def __init__(self, l2_axis: Union[None, int, Tuple] = 0):
r"""
Args:
l2_axis: Axis/axes over which to take the l2 norm. Default: 0.
l2_axis: Axis/axes over which to take the l2 norm. Required
to be ``None`` for :class:`.BlockArray` inputs to be
supported.
"""
self.l2_axis = l2_axis

@staticmethod
def _l2norm(
x: Union[Array, BlockArray], axis: Union[int, Tuple], keepdims: Optional[bool] = False
x: Union[Array, BlockArray], axis: Union[None, int, Tuple], keepdims: Optional[bool] = False
):
r"""Return the :math:`\ell_2` norm of an array."""
return snp.sqrt(snp.sum(snp.abs(x) ** 2, axis=axis, keepdims=keepdims))
return snp.sqrt((snp.abs(x) ** 2).sum(axis=axis, keepdims=keepdims))

def __call__(self, x: Union[Array, BlockArray]) -> float:
if isinstance(x, snp.BlockArray) and self.l2_axis is not None:
raise ValueError("Initializer parameter l2_axis must be None for BlockArray input.")
l2 = L21Norm._l2norm(x, axis=self.l2_axis)
return snp.abs(l2).sum()
return snp.sum(snp.abs(l2))

def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
Expand Down Expand Up @@ -249,6 +252,8 @@ def prox(
kwargs: Additional arguments that may be used by derived
classes.
"""
if isinstance(v, snp.BlockArray) and self.l2_axis is not None:
raise ValueError("Initializer parameter l2_axis must be None for BlockArray input.")
length = L21Norm._l2norm(v, axis=self.l2_axis, keepdims=True)
direction = no_nan_divide(v, length)

Expand Down
22 changes: 0 additions & 22 deletions scico/test/functional/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,28 +110,6 @@ def foo(c):
np.testing.assert_allclose(non_pmap, pmapped)


@pytest.mark.parametrize("axis", [0, 1, (0, 2)])
def test_l21norm(axis):
x = np.ones((3, 4, 5))
if isinstance(axis, int):
l2axis = (axis,)
else:
l2axis = axis
l2shape = [x.shape[k] for k in l2axis]
l1axis = tuple(set(range(len(x))) - set(l2axis))
l1shape = [x.shape[k] for k in l1axis]

l21ana = np.sqrt(np.prod(l2shape)) * np.prod(l1shape)
F = functional.L21Norm(l2_axis=axis)
l21num = F(x)
np.testing.assert_allclose(l21ana, l21num, rtol=1e-5)

l2ana = np.sqrt(np.prod(l2shape))
prxana = (l2ana - 1.0) / l2ana * x
prxnum = F.prox(x, 1.0)
np.testing.assert_allclose(prxana, prxnum, rtol=1e-5)


def test_scalar_aggregation():
f = functional.L2Norm()
g = 2.0 * f
Expand Down
44 changes: 44 additions & 0 deletions scico/test/functional/test_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np

import pytest

import scico.numpy as snp
from scico import functional


@pytest.mark.parametrize("axis", [0, 1, (0, 2)])
def test_l21norm(axis):
x = np.ones((3, 4, 5))
if isinstance(axis, int):
l2axis = (axis,)
else:
l2axis = axis
l2shape = [x.shape[k] for k in l2axis]
l1axis = tuple(set(range(len(x))) - set(l2axis))
l1shape = [x.shape[k] for k in l1axis]

l21ana = np.sqrt(np.prod(l2shape)) * np.prod(l1shape)
F = functional.L21Norm(l2_axis=axis)
l21num = F(x)
np.testing.assert_allclose(l21ana, l21num, rtol=1e-5)

l2ana = np.sqrt(np.prod(l2shape))
prxana = (l2ana - 1.0) / l2ana * x
prxnum = F.prox(x, 1.0)
np.testing.assert_allclose(prxana, prxnum, rtol=1e-5)


def test_l2norm_blockarray():
xa = np.random.randn(2, 3, 4)
xb = snp.blockarray((xa[0], xa[1]))

fa = functional.L21Norm(l2_axis=(1, 2))
fb = functional.L21Norm(l2_axis=None)

np.testing.assert_allclose(fa(xa), fb(xb), rtol=1e-6)

ya = fa.prox(xa)
yb = fb.prox(xb)

np.testing.assert_allclose(ya[0], yb[0], rtol=1e-6)
np.testing.assert_allclose(ya[1], yb[1], rtol=1e-6)
Loading