From ec37e48c3fa9d856b188d02226f4bd044b9de6be Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 23 Mar 2022 09:56:14 -0600 Subject: [PATCH] Add support for addition of functionals (#258) * Add sum operator for functionals * Trivial edit * Minor docstring edit * Add test * Update change summary * Replace raising NotImplementedError with returning NotImplemented --- CHANGES.rst | 12 +++++++++-- scico/functional/_functional.py | 35 ++++++++++++++++++++++++++++----- scico/solver.py | 2 +- scico/test/test_functional.py | 10 ++++++++++ 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 157ec0b13..8dcadcfc7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,7 +6,15 @@ SCICO Release Notes Version 0.0.3 (unreleased) ---------------------------- -• No changes yet +• Support for ``jaxlib`` versions 0.3.0 to 0.3.2 and ``jax`` versions + 0.3.0 to 0.3.4. +• Rename linear operators in ``radon_astra`` and ``radon_svmbir`` modules + to ``TomographicProjector``. +• Add support for fan beam CT in ``radon_svmbir`` module. +• Add function ``linop.linop_from_function`` for constructing linear + operators from functions. +• Add support for addition of functionals. + Version 0.0.2 (2022-02-14) @@ -21,7 +29,7 @@ Version 0.0.2 (2022-02-14) • Renamed "Primal Rsdl" to "Prml Rsdl" in displayed iteration stats. • Move some functions from ``util`` and ``math`` modules to new ``array`` module. -• Bump pinned `jaxlib` and `jax` versions to 0.3.0. +• Bump pinned ``jaxlib`` and ``jax`` versions to 0.3.0. Version 0.0.1 (2021-11-24) diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index fb86f4519..408946805 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -45,13 +45,16 @@ def __repr__(self): def __mul__(self, other): if snp.isscalar(other) or isinstance(other, jax.core.Tracer): return ScaledFunctional(self, other) - raise NotImplementedError( - f"Operation __mul__ not defined between {type(self)} and {type(other)}" - ) + return NotImplemented def __rmul__(self, other): return self.__mul__(other) + def __add__(self, other): + if isinstance(other, Functional): + return FunctionalSum(self, other) + return NotImplemented + def __call__(self, x: Union[JaxArray, BlockArray]) -> float: r"""Evaluate this functional at point :math:`\mb{x}`. @@ -127,6 +130,28 @@ def grad(self, x: Union[JaxArray, BlockArray]): return self._grad(x) +class FunctionalSum(Functional): + r"""A sum of two functionals.""" + + def __repr__(self): + return ( + "Sum of functionals of types " + + str(type(self.functional1)) + + " and " + + str(type(self.functional2)) + ) + + def __init__(self, functional1: Functional, functional2: Functional): + self.functional1 = functional1 + self.functional2 = functional2 + self.has_eval = functional1.has_eval and functional2.has_eval + self.has_prox = False + super().__init__() + + def __call__(self, x: Union[JaxArray, BlockArray]) -> float: + return self.functional1(x) + self.functional2(x) + + class ScaledFunctional(Functional): r"""A functional multiplied by a scalar.""" @@ -156,8 +181,8 @@ def prox( factors, i.e., for functional :math:`f` and scaling factors :math:`\alpha` and :math:`\beta`, the proximal operator with scaling parameter :math:`\alpha` of scaled functional :math:`\beta f` is - the proximal operator with scaling parameter :math:`\alpha \beta` of - functional :math:`f`, + the proximal operator with scaling parameter :math:`\alpha \beta` + of functional :math:`f`, .. math:: \mathrm{prox}_{\alpha (\beta f)}(\mb{v}) = diff --git a/scico/solver.py b/scico/solver.py index 60a431de4..ba42d308d 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -183,7 +183,7 @@ def minimize( Wrapper around :func:`scipy.optimize.minimize`. This function differs from :func:`scipy.optimize.minimize` in three ways: - - The `jac` options of :func:`scipy.optimize.minimize` are not + - The `jac` options of :func:`scipy.optimize.minimize` are not supported. The gradient is calculated using `jax.grad`. - Functions mapping from N-dimensional arrays -> float are supported. diff --git a/scico/test/test_functional.py b/scico/test/test_functional.py index f95dae06d..21f86a7a0 100644 --- a/scico/test/test_functional.py +++ b/scico/test/test_functional.py @@ -260,6 +260,16 @@ def test_has_prox(self, cls): assert isinstance(cls.has_prox, bool) +def test_functional_sum(): + x = np.random.randn(4, 4) + f0 = functional.L1Norm() + f1 = 2.0 * functional.L2Norm() + f = f0 + f1 + assert f(x) == f0(x) + f1(x) + with pytest.raises(TypeError): + f = f0 + 2.0 + + def test_scalar_vmap(): x = np.random.randn(4, 4) f = functional.L1Norm()