Skip to content

Commit

Permalink
Resolve obstacles to jitting of some functionals (#467)
Browse files Browse the repository at this point in the history
* Resolve jit issues due to Python conditionals

* Add tests for jit-ability of functionals

* Explicitly skip BM3D and BM4D functional tests due to issues on OSX

* Fix previous change: edit applied to wrong list
  • Loading branch information
bwohlberg authored Nov 9, 2023
1 parent 092a917 commit a03aa56
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 26 deletions.
69 changes: 44 additions & 25 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ def prox(
classes.
"""
norm_v = norm(v)
if norm_v == 0:
return 0 * v
return snp.maximum(1 - lam / norm_v, 0) * v
return snp.where(norm_v == 0, 0 * v, snp.maximum(1 - lam / norm_v, 0) * v)


class L21Norm(Functional):
Expand Down Expand Up @@ -283,16 +281,48 @@ def __init__(self, beta: float = 1.0):
def __call__(self, x: Union[Array, BlockArray]) -> float:
return snp.sum(snp.abs(x)) - self.beta * norm(x)

@staticmethod
def _prox_vamx_ge_thresh(v, va, vs, alpha, beta):
u = snp.zeros(v.shape, dtype=v.dtype)
idx = va.ravel().argmax()
u = (
u.ravel().at[idx].set((va.ravel()[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx])
).reshape(v.shape)
return u

@staticmethod
def _prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta):
return snp.where(
vamx < (1.0 - beta) * alpha,
snp.zeros(v.shape, dtype=v.dtype),
L1MinusL2Norm._prox_vamx_ge_thresh(v, va, vs, alpha, beta),
)

@staticmethod
def _prox_vamx_gt_alpha(v, va, vs, alpha, beta):
u = snp.maximum(va - alpha, 0.0) * vs
l2u = norm(u)
u *= (l2u + alpha * beta) / l2u
return u

@staticmethod
def _prox_vamx_gt_0(v, va, vs, vamx, alpha, beta):
return snp.where(
vamx > alpha,
L1MinusL2Norm._prox_vamx_gt_alpha(v, va, vs, alpha, beta),
L1MinusL2Norm._prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta),
)

def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Proximal operator of difference of :math:`\ell_1` and :math:`\ell_2` norms
r"""Proximal operator of difference of :math:`\ell_1` and :math:`\ell_2` norms.
Evaluate the proximal operator of the difference of :math:`\ell_1`
and :math:`\ell_2` norms, i.e. :math:`\alpha \left( \| \mb{x} \|_1 -
\beta \| \mb{x} \|_2 \right)` :cite:`lou-2018-fast`. Note that this
is not a proximal operator according to the strict definition since
the loss function is non-convex.
and :math:`\ell_2` norms, i.e. :math:`\alpha \left( \| \mb{x}
\|_1 - \beta \| \mb{x} \|_2 \right)` :cite:`lou-2018-fast`. Note
that this is not a proximal operator according to the strict
definition since the loss function is non-convex.
Args:
v: Input array :math:`\mb{v}`.
Expand All @@ -308,23 +338,12 @@ def prox(
vs = snp.exp(1j * snp.angle(v))
else:
vs = snp.sign(v)
if vamx > 0.0:
if vamx > alpha:
u = snp.maximum(va - alpha, 0.0) * vs
l2u = norm(u)
u *= (l2u + alpha * beta) / l2u
else:
u = snp.zeros(v.shape, dtype=v.dtype)
if vamx >= (1.0 - beta) * alpha:
idx = va.ravel().argmax()
u = (
u.ravel()
.at[idx]
.set((va.ravel()[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx])
).reshape(v.shape)
else:
u = snp.zeros(v.shape, dtype=v.dtype)
return u

return snp.where(
vamx > 0.0,
L1MinusL2Norm._prox_vamx_gt_0(v, va, vs, vamx, alpha, beta),
snp.zeros(v.shape, dtype=v.dtype),
)


class HuberNorm(Functional):
Expand Down
44 changes: 43 additions & 1 deletion scico/test/functional/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ class TestCheckAttrs:
# and set to True/False in the Functional subclasses.

# Generate a list of all functionals in scico.functionals that we will check
ignore = [functional.Functional, functional.ScaledFunctional, functional.SeparableFunctional]
ignore = [
functional.Functional,
functional.ScaledFunctional,
functional.SeparableFunctional,
]
to_check = []
for name, cls in functional.__dict__.items():
if isinstance(cls, type):
Expand All @@ -30,6 +34,44 @@ def test_has_prox(self, cls):
assert isinstance(cls.has_prox, bool)


class TestJit:
# Test whether functionals can be jitted.

# Generate a list of all functionals in scico.functionals that we will check
ignore = [
functional.Functional,
functional.ScaledFunctional,
functional.SeparableFunctional,
functional.BM3D,
functional.BM4D,
]
to_check = []
for name, cls in functional.__dict__.items():
if isinstance(cls, type):
if issubclass(cls, functional.Functional):
if cls not in ignore:
to_check.append(cls)

@pytest.mark.parametrize("cls", to_check)
def test_jit(self, cls):
# Only test functionals that have no required __init__ parameters.
try:
f = cls()
except TypeError:
pass
else:
v = snp.arange(4.0)
# Only test functionals that can take 1D input.
try:
u0 = f.prox(v)
except ValueError:
pass
else:
fprox = jax.jit(f.prox)
u1 = fprox(v)
assert np.allclose(u0, u1)


def test_functional_sum():
x = np.random.randn(4, 4)
f0 = functional.L1Norm()
Expand Down

0 comments on commit a03aa56

Please sign in to comment.