
Commit

Merge branch 'main' into brendt/docs
bwohlberg authored Mar 23, 2022
2 parents c92a78a + ec37e48 commit 991c8a0
Showing 4 changed files with 51 additions and 8 deletions.
12 changes: 10 additions & 2 deletions CHANGES.rst
@@ -6,7 +6,15 @@ SCICO Release Notes
Version 0.0.3 (unreleased)
----------------------------

-• No changes yet
+• Support for ``jaxlib`` versions 0.3.0 to 0.3.2 and ``jax`` versions
+  0.3.0 to 0.3.4.
+• Rename linear operators in ``radon_astra`` and ``radon_svmbir`` modules
+  to ``TomographicProjector``.
+• Add support for fan beam CT in ``radon_svmbir`` module.
+• Add function ``linop.linop_from_function`` for constructing linear
+  operators from functions.
+• Add support for addition of functionals.



Version 0.0.2 (2022-02-14)
@@ -21,7 +29,7 @@ Version 0.0.2 (2022-02-14)
• Renamed "Primal Rsdl" to "Prml Rsdl" in displayed iteration stats.
• Move some functions from ``util`` and ``math`` modules to new ``array``
module.
-• Bump pinned `jaxlib` and `jax` versions to 0.3.0.
+• Bump pinned ``jaxlib`` and ``jax`` versions to 0.3.0.


Version 0.0.1 (2021-11-24)
35 changes: 30 additions & 5 deletions scico/functional/_functional.py
@@ -45,13 +45,16 @@ def __repr__(self):
    def __mul__(self, other):
        if snp.isscalar(other) or isinstance(other, jax.core.Tracer):
            return ScaledFunctional(self, other)
-        raise NotImplementedError(
-            f"Operation __mul__ not defined between {type(self)} and {type(other)}"
-        )
+        return NotImplemented

    def __rmul__(self, other):
        return self.__mul__(other)

    def __add__(self, other):
        if isinstance(other, Functional):
            return FunctionalSum(self, other)
        return NotImplemented

    def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
        r"""Evaluate this functional at point :math:`\mb{x}`.
@@ -127,6 +130,28 @@ def grad(self, x: Union[JaxArray, BlockArray]):
        return self._grad(x)


class FunctionalSum(Functional):
    r"""A sum of two functionals."""

    def __repr__(self):
        return (
            "Sum of functionals of types "
            + str(type(self.functional1))
            + " and "
            + str(type(self.functional2))
        )

    def __init__(self, functional1: Functional, functional2: Functional):
        self.functional1 = functional1
        self.functional2 = functional2
        self.has_eval = functional1.has_eval and functional2.has_eval
        self.has_prox = False
        super().__init__()

    def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
        return self.functional1(x) + self.functional2(x)

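A quick sketch of what the new ``__add__`` support enables (not part of the commit; it relies only on the public ``scico.functional`` names that appear elsewhere in this diff): adding two functionals yields a ``FunctionalSum`` that evaluates as the sum of its terms, keeps ``has_eval`` only if both terms have it, and reports ``has_prox = False``.

import numpy as np

from scico import functional

x = np.random.randn(4, 4)
f0 = functional.L1Norm()
f1 = 2.0 * functional.L2Norm()  # ScaledFunctional
f = f0 + f1  # FunctionalSum via Functional.__add__

# evaluation distributes over the terms
assert np.allclose(f(x), f0(x) + f1(x))

# the sum can be evaluated but has no proximal operator
assert f.has_eval and not f.has_prox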

class ScaledFunctional(Functional):
r"""A functional multiplied by a scalar."""

@@ -156,8 +181,8 @@ def prox(
        factors, i.e., for functional :math:`f` and scaling factors
        :math:`\alpha` and :math:`\beta`, the proximal operator with scaling
        parameter :math:`\alpha` of scaled functional :math:`\beta f` is
-        the proximal operator with scaling parameter :math:`\alpha \beta` of
-        functional :math:`f`,
+        the proximal operator with scaling parameter :math:`\alpha \beta`
+        of functional :math:`f`,
        .. math::
            \mathrm{prox}_{\alpha (\beta f)}(\mb{v}) =
            \mathrm{prox}_{(\alpha \beta) f}(\mb{v}) \;.
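The scaling identity in the docstring above is easy to sanity-check numerically. The following sketch is not part of the commit; it assumes the ``prox(v, lam)`` calling convention used by ``scico.functional`` objects, and uses the L1 norm (whose prox is soft thresholding) with arbitrary test values.

import numpy as np

from scico import functional

v = np.random.randn(8)
f = functional.L1Norm()
g = 2.0 * f  # ScaledFunctional with scale beta = 2.0

# prox of (beta f) with parameter alpha equals prox of f with parameter alpha * beta
alpha = 0.3
np.testing.assert_allclose(g.prox(v, alpha), f.prox(v, alpha * 2.0))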
2 changes: 1 addition & 1 deletion scico/solver.py
@@ -183,7 +183,7 @@ def minimize(
    Wrapper around :func:`scipy.optimize.minimize`. This function differs
    from :func:`scipy.optimize.minimize` in three ways:

    - The `jac` options of :func:`scipy.optimize.minimize` are not
      supported. The gradient is calculated using `jax.grad`.
    - Functions mapping from N-dimensional arrays -> float are
      supported.
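As a rough illustration of the two differences listed above, here is a sketch (not from this commit, and assuming the scipy-compatible ``minimize(func, x0, method=...)`` calling convention) in which the objective maps a 2D array to a float and its gradient is obtained automatically via ``jax.grad``:

import jax.numpy as jnp

from scico import solver

c = jnp.arange(4.0).reshape(2, 2)

def f(x):
    # maps a (2, 2) array to a float; the gradient is supplied by jax.grad
    return jnp.sum((x - c) ** 2)

res = solver.minimize(f, jnp.zeros((2, 2)), method="L-BFGS-B")
print(res.x)  # expected to be close to c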
10 changes: 10 additions & 0 deletions scico/test/test_functional.py
@@ -260,6 +260,16 @@ def test_has_prox(self, cls):
        assert isinstance(cls.has_prox, bool)


def test_functional_sum():
    x = np.random.randn(4, 4)
    f0 = functional.L1Norm()
    f1 = 2.0 * functional.L2Norm()
    f = f0 + f1
    assert f(x) == f0(x) + f1(x)
    with pytest.raises(TypeError):
        f = f0 + 2.0


def test_scalar_vmap():
    x = np.random.randn(4, 4)
    f = functional.L1Norm()
