diff --git a/.github/codecov.yml b/.github/codecov.yml
index 9de3e3f04..2134d07fc 100644
--- a/.github/codecov.yml
+++ b/.github/codecov.yml
@@ -1,3 +1,11 @@
coverage:
+ precision: 2
+ round: nearest
+ range: "80...100"
+
status:
+ project:
+ default:
+ target: auto
+ threshold: 0.05%
patch: false
diff --git a/examples/scripts/sparsecode_admm.py b/examples/scripts/sparsecode_admm.py
index ed500f60e..32f3eafde 100644
--- a/examples/scripts/sparsecode_admm.py
+++ b/examples/scripts/sparsecode_admm.py
@@ -24,7 +24,7 @@
import jax
from scico import functional, linop, loss, plot
-from scico.optimize.admm import ADMM, LinearSubproblemSolver
+from scico.optimize.admm import ADMM, MatrixSubproblemSolver
from scico.util import device_info
"""
@@ -67,7 +67,7 @@
rho_list=rho_list,
x0=A.adj(y),
maxiter=maxiter,
- subproblem_solver=LinearSubproblemSolver(),
+ subproblem_solver=MatrixSubproblemSolver(),
itstat_options={"display": True, "period": 10},
)
diff --git a/scico/linop/_diag.py b/scico/linop/_diag.py
index 30cd21949..0e9251f57 100644
--- a/scico/linop/_diag.py
+++ b/scico/linop/_diag.py
@@ -42,9 +42,9 @@ def __init__(
r"""
Args:
diagonal: Diagonal elements of this :class:`LinearOperator`.
- input_shape: Shape of input array. By default, equal to
+ input_shape: Shape of input array. By default, equal to
`diagonal.shape`, but may also be set to a shape that is
- broadcast-compatiable with `diagonal.shape`.
+ broadcast-compatible with `diagonal.shape`.
input_dtype: `dtype` of input argument. The default,
``None``, means `diagonal.dtype`.
"""
@@ -64,7 +64,7 @@ def __init__(
elif isinstance(diagonal, BlockArray):
raise ValueError("Parameter diagonal was a BlockArray but input_shape was not nested.")
else:
- raise ValueError("Parameter diagonal was a not BlockArray but input_shape was nested.")
+ raise ValueError("Parameter diagonal was not a BlockArray but input_shape was nested.")
super().__init__(
input_shape=input_shape,
@@ -77,6 +77,29 @@ def __init__(
def _eval(self, x):
return x * self.diagonal
+ @property
+ def T(self) -> Diagonal:
+ """Transpose of this :class:`Diagonal`."""
+ return self
+
+ def conj(self) -> Diagonal:
+ """Complex conjugate of this :class:`Diagonal`."""
+ return Diagonal(diagonal=self.diagonal.conj())
+
+ @property
+ def H(self) -> Diagonal:
+ """Hermitian transpose of this :class:`Diagonal`."""
+ return self.conj()
+
+ @property
+ def gram_op(self) -> Diagonal:
+ """Gram operator of this :class:`Diagonal`.
+
+ Return a new :class:`Diagonal` :code:`G` such that
+ :code:`G(x) = A.adj(A(x))`.
+ """
+ return Diagonal(diagonal=self.diagonal.conj() * self.diagonal)
+
@partial(_wrap_add_sub, op=operator.add)
def __add__(self, other):
if self.diagonal.shape == other.diagonal.shape:
@@ -101,6 +124,41 @@ def __rmul__(self, scalar):
def __truediv__(self, scalar):
return Diagonal(diagonal=self.diagonal / scalar)
+ def __matmul__(self, other):
+ # self @ other
+ if isinstance(other, Diagonal):
+ if self.shape == other.shape:
+ return Diagonal(diagonal=self.diagonal * other.diagonal)
+
+ raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")
+
+ else:
+ return self(other)
+
+ def norm(self, ord=None): # pylint: disable=W0622
+ """Compute the matrix norm of the diagonal operator.
+
+ Valid values of `ord` and the corresponding norm definition
+ are those listed under "norm for matrices" in the
+ :func:`scico.numpy.linalg.norm` documentation.
+ """
+ ordfunc = {
+ "fro": lambda x: snp.linalg.norm(x),
+ "nuc": lambda x: snp.sum(snp.abs(x)),
+ -snp.inf: lambda x: snp.abs(x).min(),
+ snp.inf: lambda x: snp.abs(x).max(),
+ }
+ mord = ord
+ if mord is None:
+ mord = "fro"
+ elif mord in (-1, -2):
+ mord = -snp.inf
+ elif mord in (1, 2):
+ mord = snp.inf
+ if mord not in ordfunc:
+ raise ValueError(f"Invalid value {ord} for parameter ord.")
+ return ordfunc[mord](self.diagonal)
+
class Identity(Diagonal):
"""Identity operator."""
diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py
index e2d35e3c7..0f6f21daa 100644
--- a/scico/linop/_matrix.py
+++ b/scico/linop/_matrix.py
@@ -172,7 +172,7 @@ def __mul__(self, other):
raise TypeError(f"Operation __mul__ not defined between {type(self)} and {type(other)}.")
def __rmul__(self, other):
- # Multiplication is commutative
+ # multiplication is commutative
return self * other
def __truediv__(self, other):
@@ -255,6 +255,6 @@ def gram_op(self):
def norm(self, ord=None, axis=None, keepdims=False): # pylint: disable=W0622
"""Compute the norm of the dense matrix `self.A`.
- Call :func:`scico.numpy.norm` on the dense matrix `self.A`.
+ Call :func:`scico.numpy.linalg.norm` on the dense matrix `self.A`.
"""
return snp.linalg.norm(self.A, ord=ord, axis=axis, keepdims=keepdims)
diff --git a/scico/optimize/_admm.py b/scico/optimize/_admm.py
index ec4fa8edf..55862f0cb 100644
--- a/scico/optimize/_admm.py
+++ b/scico/optimize/_admm.py
@@ -25,6 +25,7 @@
G0BlockCircularConvolveSolver,
GenericSubproblemSolver,
LinearSubproblemSolver,
+ MatrixSubproblemSolver,
SubproblemSolver,
)
from ._common import Optimizer
@@ -189,7 +190,7 @@ def _itstat_extra_fields(self):
)
elif (
type(self.subproblem_solver)
- in [FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver]
+ in [MatrixSubproblemSolver, FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver]
and self.subproblem_solver.check_solve
):
itstat_fields.update({"Slv Res": "%9.3e"})
diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py
index 5d670d7e6..7a0ceb0b2 100644
--- a/scico/optimize/_admmaux.py
+++ b/scico/optimize/_admmaux.py
@@ -23,14 +23,15 @@
from scico.linop import (
CircularConvolve,
ComposedLinearOperator,
+ Diagonal,
Identity,
LinearOperator,
- Sum,
+ MatrixOperator,
)
from scico.loss import SquaredL2Loss
-from scico.metric import rel_res
from scico.numpy import Array, BlockArray
from scico.numpy.util import ensure_on_device, is_real_dtype
+from scico.solver import ATADSolver, ConvATADSolver
from scico.solver import cg as scico_cg
from scico.solver import minimize
@@ -126,8 +127,8 @@ class LinearSubproblemSolver(SubproblemSolver):
for the case where :code:`f` is an :math:`\ell_2` or weighted
:math:`\ell_2` norm, and :code:`f.A` is a linear operator, so that
the subproblem involves solving a linear equation. This requires that
- `f.functional` be an instance of :class:`.SquaredL2Loss` and for
- the forward operator :code:`f.A` to be an instance of
+ :code:`f.functional` be an instance of :class:`.SquaredL2Loss` and
+ for the forward operator :code:`f.A` to be an instance of
:class:`.LinearOperator`.
The :math:`\mb{x}`-update step is
@@ -139,7 +140,7 @@ class LinearSubproblemSolver(SubproblemSolver):
\norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;,
where :math:`W` a weighting :class:`.Diagonal` operator
- or an :class:`.Identity` operator (i.e. no weighting).
+ or an :class:`.Identity` operator (i.e., no weighting).
This update step reduces to the solution of the linear system
.. math::
@@ -200,19 +201,19 @@ def __init__(self, cg_kwargs: Optional[dict[str, Any]] = None, cg_function: str
def internal_init(self, admm: soa.ADMM):
if admm.f is not None:
if not isinstance(admm.f, SquaredL2Loss):
- raise ValueError(
+ raise TypeError(
"LinearSubproblemSolver requires f to be a scico.loss.SquaredL2Loss; "
f"got {type(admm.f)}."
)
if not isinstance(admm.f.A, LinearOperator):
- raise ValueError(
+ raise TypeError(
"LinearSubproblemSolver requires f.A to be a scico.linop.LinearOperator; "
f"got {type(admm.f.A)}."
)
super().internal_init(admm)
- # Set lhs_op = \sum_i rho_i * Ci.H @ CircularConvolve
+ # Set lhs_op = \sum_i rho_i * Ci.H @ Ci
# Use reduce as the initialization of this sum is messy otherwise
lhs_op = reduce(
lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)]
@@ -266,6 +267,110 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
return x
+class MatrixSubproblemSolver(LinearSubproblemSolver):
+ r"""Solver for quadratic functionals involving matrix operators.
+
+ Solver for the case in which :math:`f` is a quadratic function of
+ :math:`\mb{x}`, and :math:`A` and all the :math:`C_i` are diagonal
+ or matrix operators. It is a specialization of
+ :class:`.LinearSubproblemSolver`.
+
+ As for :class:`.LinearSubproblemSolver`, the :math:`\mb{x}`-update
+ step is
+
+ .. math::
+
+ \mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2}
+ \norm{\mb{y} - A \mb{x}}_W^2 + \sum_i \frac{\rho_i}{2}
+ \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;,
+
+ where :math:`W` is a weighting :class:`.Diagonal` operator
+ or an :class:`.Identity` operator (i.e., no weighting).
+ This update step reduces to the solution of the linear system
+
+ .. math::
+
+ \left(A^H W A + \sum_{i=1}^N \rho_i C_i^H C_i \right)
+ \mb{x}^{(k+1)} = \;
+ A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i -
+ \mb{u}^{(k)}_i) \;,
+
+ which is solved by factorization of the left hand side of the
+ equation, using :class:`.ATADSolver`.
+
+
+ Attributes:
+ admm (:class:`.ADMM`): ADMM solver object to which the solver is
+ attached.
+ solve_kwargs (dict): Dictionary of arguments for solver
+ :class:`.ATADSolver` initialization.
+ """
+
+ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, Any]] = None):
+ """Initialize a :class:`MatrixSubproblemSolver` object.
+
+ Args:
+ check_solve: If ``True``, compute solver accuracy after each
+ solve.
+ solve_kwargs: Dictionary of arguments for solver
+ :class:`.ATADSolver` initialization.
+ """
+ self.check_solve = check_solve
+ default_solve_kwargs = {"cho_factor": False}
+ if solve_kwargs:
+ default_solve_kwargs.update(solve_kwargs)
+ self.solve_kwargs = default_solve_kwargs
+
+ def internal_init(self, admm: soa.ADMM):
+ if admm.f is not None:
+ if not isinstance(admm.f, SquaredL2Loss):
+ raise TypeError(
+ "MatrixSubproblemSolver requires f to be a scico.loss.SquaredL2Loss; "
+ f"got {type(admm.f)}."
+ )
+ if not isinstance(admm.f.A, (Diagonal, MatrixOperator)):
+ raise TypeError(
+ "MatrixSubproblemSolver requires f.A to be a Diagonal or MatrixOperator; "
+ f"got {type(admm.f.A)}."
+ )
+ for i, Ci in enumerate(admm.C_list):
+ if not isinstance(Ci, (Diagonal, MatrixOperator)):
+ raise TypeError(
+ "MatrixSubproblemSolver requires C[{i}] to be a Diagonal or MatrixOperator; "
+ f"got {type(Ci)}."
+ )
+
+ super().internal_init(admm)
+
+ if admm.f is None:
+ A = snp.zeros(admm.C_list[0].input_shape[0], dtype=admm.C_list[0].input_dtype)
+ W = None
+ else:
+ A = admm.f.A
+ W = 2.0 * self.admm.f.scale * admm.f.W # type: ignore
+
+ Csum = reduce(
+ lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)]
+ )
+ self.solver = ATADSolver(A, Csum, W, **self.solve_kwargs)
+
+ def solve(self, x0: Array) -> Array:
+ """Solve the ADMM step.
+
+ Args:
+ x0: Initial value (ignored).
+
+ Returns:
+ Computed solution.
+ """
+ rhs = self.compute_rhs()
+ x = self.solver.solve(rhs)
+ if self.check_solve:
+ self.accuracy = self.solver.accuracy(x, rhs)
+
+ return x
+
+
class CircularConvolveSolver(LinearSubproblemSolver):
r"""Solver for linear operators diagonalized in the DFT domain.
@@ -293,12 +398,12 @@ def __init__(self):
def internal_init(self, admm: soa.ADMM):
if admm.f is not None:
if not isinstance(admm.f, SquaredL2Loss):
- raise ValueError(
+ raise TypeError(
"CircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; "
f"got {type(admm.f)}."
)
if not isinstance(admm.f.A, (CircularConvolve, Identity)):
- raise ValueError(
+ raise TypeError(
"CircularConvolveSolver requires f.A to be a scico.linop.CircularConvolve "
f"or scico.linop.Identity; got {type(admm.f.A)}."
)
@@ -454,35 +559,19 @@ def internal_init(self, admm: soa.ADMM):
raise ValueError("FBlockCircularConvolveSolver does not allow f to be None.")
else:
if not isinstance(admm.f, SquaredL2Loss):
- raise ValueError(
+ raise TypeError(
"FBlockCircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; "
f"got {type(admm.f)}."
)
if not isinstance(admm.f.A, ComposedLinearOperator):
- raise ValueError(
+ raise TypeError(
"FBlockCircularConvolveSolver requires f.A to be a composition of Sum "
f"and CircularConvolve linear operators; got {type(admm.f.A)}."
)
- if not isinstance(admm.f.A.A, Sum) or not isinstance(admm.f.A.B, CircularConvolve):
- raise ValueError(
- "FBlockCircularConvolveSolver requires f.A to be a composition of Sum "
- "and CircularConvolve linear operators; got a composition of "
- f"{type(admm.f.A.A)} and {type(admm.f.A.B)}."
- )
- self.sum_axis = admm.f.A.A.kwargs["axis"]
- if not isinstance(self.sum_axis, int):
- raise ValueError(
- "FBlockCircularConvolveSolver requires the Sum operator component "
- "of f.A to sum over a single axis of its input."
- )
-
super().internal_init(admm)
assert isinstance(self.admm.f, SquaredL2Loss)
assert isinstance(self.admm.f.A, ComposedLinearOperator)
- assert isinstance(self.admm.f.A.B, CircularConvolve)
-
- self.real_result = is_real_dtype(admm.C_list[0].input_dtype)
# All of the C operators are assumed to be linear and shift invariant
# but this is not checked.
@@ -490,12 +579,8 @@ def internal_init(self, admm: soa.ADMM):
rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims)
for rho, C in zip(admm.rho_list, admm.C_list)
]
- self.D = reduce(lambda a, b: a + b, c_gram_list) / (2.0 * self.admm.f.scale)
- A = self.admm.f.A.B
- self.AHEinv = A.h_dft.conj() / (
- 1.0
- + snp.sum(A.h_dft * (A.h_dft.conj() / self.D.h_dft), axis=self.sum_axis, keepdims=True)
- )
+ D = reduce(lambda a, b: a + b, c_gram_list) / (2.0 * self.admm.f.scale)
+ self.solver = ConvATADSolver(self.admm.f.A, D)
def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
"""Solve the ADMM step.
@@ -507,27 +592,11 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
Computed solution.
"""
assert isinstance(self.admm.f, SquaredL2Loss)
- assert isinstance(self.admm.f.A, ComposedLinearOperator)
- assert isinstance(self.admm.f.A.B, CircularConvolve)
rhs = self.compute_rhs() / (2.0 * self.admm.f.scale)
- fft_axes = self.admm.f.A.B.x_fft_axes
- rhs_dft = snp.fft.fftn(rhs, axes=fft_axes)
- A = self.admm.f.A.B
- x_dft = (
- rhs_dft
- - (
- self.AHEinv
- * (snp.sum(A.h_dft * rhs_dft / self.D.h_dft, axis=self.sum_axis, keepdims=True))
- )
- ) / self.D.h_dft
- x = snp.fft.ifftn(x_dft, axes=fft_axes)
- if self.real_result:
- x = x.real
-
+ x = self.solver.solve(rhs)
if self.check_solve:
- lhs = self.admm.f.A.gram_op(x) + self.D(x)
- self.accuracy = rel_res(lhs, rhs)
+ self.accuracy = self.solver.accuracy(x, rhs)
return x
@@ -537,7 +606,7 @@ class G0BlockCircularConvolveSolver(SubproblemSolver):
domain.
Specialization of :class:`.LinearSubproblemSolver` for the case
- where :math:`f = 0` (i.e. :code:`f` is a :class:`.ZeroFunctional`),
+ where :math:`f = 0` (i.e., :code:`f` is a :class:`.ZeroFunctional`),
:math:`g_1` is an instance of :class:`.SquaredL2Loss`, :math:`C_1`
is a composition of a :class:`.Sum` operator and a
:class:`.CircularConvolve` operator. The former must sum over the
@@ -661,37 +730,20 @@ def internal_init(self, admm: soa.ADMM):
"G0BlockCircularConvolveSolver requires f to be None or a ZeroFunctional"
)
if not isinstance(admm.g_list[0], SquaredL2Loss):
- raise ValueError(
+ raise TypeError(
"G0BlockCircularConvolveSolver requires g_1 to be a scico.loss.SquaredL2Loss; "
f"got {type(admm.g_list[0])}."
)
if not isinstance(admm.C_list[0], ComposedLinearOperator):
- raise ValueError(
+ raise TypeError(
"G0BlockCircularConvolveSolver requires C_1 to be a composition of Sum "
f"and CircularConvolve linear operators; got {type(admm.C_list[0])}."
)
- if not isinstance(admm.C_list[0].A, Sum) or not isinstance(
- admm.C_list[0].B, CircularConvolve
- ):
- raise ValueError(
- "G0BlockCircularConvolveSolver requires C_1 to be a composition of Sum "
- "and CircularConvolve linear operators; got a composition of "
- f"{type(admm.C_list[0].A)} and {type(admm.C_list[0].B)}."
- )
- self.sum_axis = admm.C_list[0].A.kwargs["axis"]
- if not isinstance(self.sum_axis, int):
- raise ValueError(
- "G0BlockCircularConvolveSolver requires the Sum operator component "
- "of C_1 to sum over a single axis of its input."
- )
super().internal_init(admm)
assert isinstance(self.admm.g_list[0], SquaredL2Loss)
assert isinstance(self.admm.C_list[0], ComposedLinearOperator)
- assert isinstance(self.admm.C_list[0].B, CircularConvolve)
-
- self.real_result = is_real_dtype(admm.C_list[0].input_dtype)
# All of the C operators are assumed to be linear and shift invariant
# but this is not checked.
@@ -699,14 +751,10 @@ def internal_init(self, admm: soa.ADMM):
rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims)
for rho, C in zip(admm.rho_list[1:], admm.C_list[1:])
]
- self.D = reduce(lambda a, b: a + b, c_gram_list) / (
+ D = reduce(lambda a, b: a + b, c_gram_list) / (
2.0 * self.admm.g_list[0].scale * admm.rho_list[0]
)
- A = self.admm.C_list[0].B
- self.AHEinv = A.h_dft.conj() / (
- 1.0
- + snp.sum(A.h_dft * (A.h_dft.conj() / self.D.h_dft), axis=self.sum_axis, keepdims=True)
- )
+ self.solver = ConvATADSolver(self.admm.C_list[0], D)
def compute_rhs(self) -> Union[Array, BlockArray]:
r"""Compute the right hand side of the linear equation to be solved.
@@ -720,7 +768,7 @@ def compute_rhs(self) -> Union[Array, BlockArray]:
( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i) \;.
Returns:
- Computed solution.
+ Right hand side of the linear equation.
"""
assert isinstance(self.admm.g_list[0], SquaredL2Loss)
@@ -746,26 +794,10 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
Computed solution.
"""
assert isinstance(self.admm.g_list[0], SquaredL2Loss)
- assert isinstance(self.admm.C_list[0], ComposedLinearOperator)
- assert isinstance(self.admm.C_list[0].B, CircularConvolve)
rhs = self.compute_rhs() / (2.0 * self.admm.g_list[0].scale * self.admm.rho_list[0])
- fft_axes = self.admm.C_list[0].B.x_fft_axes
- rhs_dft = snp.fft.fftn(rhs, axes=fft_axes)
- A = self.admm.C_list[0].B
- x_dft = (
- rhs_dft
- - (
- self.AHEinv
- * (snp.sum(A.h_dft * rhs_dft / self.D.h_dft, axis=self.sum_axis, keepdims=True))
- )
- ) / self.D.h_dft
- x = snp.fft.ifftn(x_dft, axes=fft_axes)
- if self.real_result:
- x = x.real
-
+ x = self.solver.solve(rhs)
if self.check_solve:
- lhs = self.admm.C_list[0].gram_op(x) + self.D(x)
- self.accuracy = rel_res(lhs, rhs)
+ self.accuracy = self.solver.accuracy(x, rhs)
return x
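The new MatrixSubproblemSolver defined above is a drop-in alternative to LinearSubproblemSolver whenever :code:`f.A` and every :code:`C_i` is a Diagonal or MatrixOperator, replacing per-iteration CG with a one-time factorization. A minimal sketch of how it might be wired into ADMM (problem sizes and the regularization weight below are illustrative, not taken from the patch):

import numpy as np
import scico.numpy as snp
from scico import functional, linop, loss, random
from scico.optimize.admm import ADMM, MatrixSubproblemSolver

M, N = 12, 8
Amx, key = random.randn((M, N), dtype=np.float32, key=None)
y, key = random.randn((M,), dtype=np.float32, key=key)

A = linop.MatrixOperator(Amx)
f = loss.SquaredL2Loss(y=y, A=A)
g_list = [1e-1 * functional.SquaredL2Norm()]  # simple Tikhonov-style regularizer
C_list = [linop.Identity((N,))]  # Identity is a Diagonal, so it is accepted

admm = ADMM(
    f=f,
    g_list=g_list,
    C_list=C_list,
    rho_list=[1e0],
    x0=A.adj(y),
    maxiter=20,
    # With check_solve=True the factorization residual is reported in the
    # "Slv Res" column of the iteration statistics (see the _admm.py change above).
    subproblem_solver=MatrixSubproblemSolver(check_solve=True),
)
x = admm.solve()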
diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py
index ee9a126ed..7f64cc2d6 100644
--- a/scico/optimize/admm.py
+++ b/scico/optimize/admm.py
@@ -14,6 +14,7 @@
SubproblemSolver,
GenericSubproblemSolver,
LinearSubproblemSolver,
+ MatrixSubproblemSolver,
CircularConvolveSolver,
FBlockCircularConvolveSolver,
G0BlockCircularConvolveSolver,
@@ -24,6 +25,7 @@
"SubproblemSolver",
"GenericSubproblemSolver",
"LinearSubproblemSolver",
+ "MatrixSubproblemSolver",
"CircularConvolveSolver",
"FBlockCircularConvolveSolver",
"G0BlockCircularConvolveSolver",
diff --git a/scico/solver.py b/scico/solver.py
index d69c47247..5f0994246 100644
--- a/scico/solver.py
+++ b/scico/solver.py
@@ -5,24 +5,14 @@
# user license can be found in the 'LICENSE' file distributed with the
# package.
-"""SciPy optimization algorithms.
-
-.. raw:: html
-
-
-
-This module provides scico interface wrappers for functions
+"""Solver and optimization algorithms.
+
+This module provides functions and classes for solving linear systems and
+optimization problems, some of which are used as subproblem solvers
+within the iterations of the proximal algorithms in the
+:mod:`scico.optimize` subpackage.
+
+This module also provides scico interface wrappers for functions
from :mod:`scipy.optimize` since jax directly implements only a very
limited subset of these functions (there is limited, experimental support
for `L-BFGS-B `_), but only CG
@@ -39,8 +29,8 @@
The wrapper also JIT compiles the function and gradient evaluations.
-The functions provided in this module have a number of advantages and
-disadvantages with respect to those in :mod:`jax.scipy.optimize`:
+These wrapper functions have a number of advantages and disadvantages
+with respect to those in :mod:`jax.scipy.optimize`:
- This module provides many more algorithms than
:mod:`jax.scipy.optimize`.
@@ -52,7 +42,7 @@
- The solvers in this module can't be JIT compiled, and gradients cannot
be taken through them.
-In the future, this module may be replaced with a dependency on
+In the future, these wrapper functions may be replaced with a dependency on
`JAXopt `__.
"""
@@ -64,10 +54,20 @@
import jax
import jax.experimental.host_callback as hcb
+import jax.scipy.linalg as jsl
-import scico.linop
import scico.numpy as snp
+from scico.linop import (
+ CircularConvolve,
+ ComposedLinearOperator,
+ Diagonal,
+ LinearOperator,
+ MatrixOperator,
+ Sum,
+)
+from scico.metric import rel_res
from scico.numpy import Array, BlockArray
+from scico.numpy.util import is_real_dtype
from scico.typing import BlockShape, DType, Shape
from scipy import optimize as spopt
@@ -336,7 +336,7 @@ def cg(
- **info**: Dictionary containing diagnostic information.
"""
if x0 is None:
- if isinstance(A, scico.linop.LinearOperator):
+ if isinstance(A, LinearOperator):
x0 = snp.zeros(A.input_shape, b.dtype)
else:
raise ValueError("Parameter x0 must be specified if A is not a LinearOperator")
@@ -418,11 +418,11 @@ def lstsq(
- **x** : Solution array.
- **info**: Dictionary containing diagnostic information.
"""
- if isinstance(A, scico.linop.LinearOperator):
+ if isinstance(A, LinearOperator):
Aop = A
else:
assert x0 is not None
- Aop = scico.linop.LinearOperator(
+ Aop = LinearOperator(
input_shape=x0.shape,
output_shape=b.shape,
eval_fn=A,
@@ -577,3 +577,371 @@ def golden(
else:
r = x
return r
+
+
+class ATADSolver:
+ r"""Solver for linear system involving a symmetric product plus a diagonal.
+
+ Solve a linear system of the form
+
+ .. math::
+
+ (A^T W A + D) \mb{x} = \mb{b}
+
+ or
+
+ .. math::
+
+ (A^T W A + D) X = B \;,
+
+ where :math:`A \in \mbb{R}^{M \times N}`,
+ :math:`W \in \mbb{R}^{M \times M}` and
+ :math:`D \in \mbb{R}^{N \times N}`. The solution is computed by
+ factorization of matrix :math:`A^T W A + D` and solution via Gaussian
+ elimination. If :math:`D` is diagonal and :math:`M < N` (i.e.
+ :math:`A W A^T` is smaller than :math:`A^T W A`), then
+ :math:`W^{-1} + A D^{-1} A^T` is factorized instead and the
+ original problem is solved
+ via the Woodbury matrix identity
+
+ .. math::
+
+ (E + U C V)^{-1} = E^{-1} - E^{-1} U (C^{-1} + V E^{-1} U)^{-1}
+ V E^{-1} \;.
+
+ Setting
+
+ .. math::
+
+ E &= D \\
+ U &= A^T \\
+ C &= W \\
+ V &= A
+
+ we have
+
+ .. math::
+
+ (D + A^T W A)^{-1} = D^{-1} - D^{-1} A^T (W^{-1} + A D^{-1} A^T)^{-1} A
+ D^{-1}
+
+ which can be simplified to
+
+ .. math::
+
+ (D + A^T W A)^{-1} = D^{-1} (I - A^T G^{-1} A D^{-1})
+
+ by defining :math:`G = W^{-1} + A D^{-1} A^T`. We therefore have that
+
+ .. math::
+
+ \mb{x} = (D + A^T W A)^{-1} \mb{b} = D^{-1} (I - A^T G^{-1} A
+ D^{-1}) \mb{b} \;.
+
+ If we have a Cholesky factorization of :math:`G`, e.g.
+ :math:`G = L L^T`, we can define
+
+ .. math::
+
+ \mb{w} = G^{-1} A D^{-1} \mb{b}
+
+ so that
+
+ .. math::
+
+ G \mb{w} &= A D^{-1} \mb{b} \\
+ L L^T \mb{w} &= A D^{-1} \mb{b} \;.
+
+ The Cholesky factorization can be exploited by solving for
+ :math:`\mb{z}` in
+
+ .. math::
+
+ L \mb{z} = A D^{-1} \mb{b}
+
+ and then for :math:`\mb{w}` in
+
+ .. math::
+
+ L^T \mb{w} = \mb{z} \;,
+
+ so that
+
+ .. math::
+
+ \mb{x} = D^{-1} \mb{b} - D^{-1} A^T \mb{w} \;.
+
+ (Functions :func:`~jax.scipy.linalg.cho_solve` and
+ :func:`~jax.scipy.linalg.lu_solve` allow direct solution for
+ :math:`\mb{w}` without the two-step procedure described here.) A
+ Cholesky factorization should only be used when :math:`G` is
+ positive-definite (e.g. :math:`D` is diagonal and positive); if not,
+ an LU factorization should be used.
+
+ Complex-valued problems are also supported, in which case the
+ transpose :math:`\cdot^T` in the equations above should be taken to
+ represent the conjugate transpose.
+
+ To solve problems directly involving a matrix of the form
+ :math:`A W A^T + D`, initialize with :code:`A.T` (or
+ :code:`A.T.conj()` for complex problems) instead of :code:`A`.
+ """
+
+ def __init__(
+ self,
+ A: Union[MatrixOperator, Array],
+ D: Union[MatrixOperator, Diagonal, Array],
+ W: Optional[Union[Diagonal, Array]] = None,
+ cho_factor: bool = False,
+ lower: bool = False,
+ check_finite: bool = True,
+ ):
+ r"""
+ Args:
+ A: Matrix :math:`A`.
+ D: Matrix :math:`D`.
+ W: Matrix :math:`W`.
+ cho_factor: Flag indicating whether to use Cholesky
+ (``True``) or LU (``False``) factorization.
+ lower: Flag indicating whether lower (``True``) or upper
+ (``False``) triangular factorization should be computed.
+ Only relevant to Cholesky factorization.
+ check_finite: Flag indicating whether the input array should
+ be checked for ``Inf`` and ``NaN`` values.
+ """
+ if isinstance(A, MatrixOperator):
+ A = A.to_array()
+ if isinstance(D, MatrixOperator):
+ D = D.to_array()
+ elif isinstance(D, Diagonal):
+ D = D.diagonal
+ if W is None:
+ W = snp.ones(A.shape[0], dtype=A.dtype)
+ elif isinstance(W, Diagonal):
+ W = W.diagonal
+ self.A = A
+ self.D = D
+ self.W = W
+ self.cho_factor = cho_factor
+ self.lower = lower
+ self.check_finite = check_finite
+
+ assert isinstance(W, Array)
+ N, M = A.shape
+ if N < M and D.ndim == 1:
+ G = snp.diag(1.0 / W) + A @ (A.T.conj() / D[:, snp.newaxis])
+ else:
+ if D.ndim == 1:
+ G = A.T.conj() @ (W[:, snp.newaxis] * A) + snp.diag(D)
+ else:
+ G = A.T.conj() @ (W[:, snp.newaxis] * A) + D
+
+ if cho_factor:
+ c, lower = jsl.cho_factor(G, lower=lower, check_finite=check_finite)
+ self.factor = (c, lower)
+ else:
+ lu, piv = jsl.lu_factor(G, check_finite=check_finite)
+ self.factor = (lu, piv)
+
+ def solve(self, b: Array, check_finite: Optional[bool] = None) -> Array:
+ r"""Solve the linear system.
+
+ Solve the linear system with right hand side :math:`\mb{b}` (`b`
+ is a vector) or :math:`B` (`b` is a 2d array).
+
+ Args:
+ b: Vector :math:`\mathbf{b}` or matrix :math:`B`.
+ check_finite: Flag indicating whether the input array should
+ be checked for ``Inf`` and ``NaN`` values. If ``None``,
+ use the value selected on initialization.
+
+ Returns:
+ Solution to the linear system.
+ """
+ if check_finite is None:
+ check_finite = self.check_finite
+ if self.cho_factor:
+ fact_solve = lambda x: jsl.cho_solve(self.factor, x, check_finite=check_finite)
+ else:
+ fact_solve = lambda x: jsl.lu_solve(self.factor, x, trans=0, check_finite=check_finite)
+
+ if b.ndim == 1:
+ D = self.D
+ else:
+ D = self.D[:, snp.newaxis]
+ N, M = self.A.shape
+ if N < M and self.D.ndim == 1:
+ w = fact_solve(self.A @ (b / D))
+ x = (b - (self.A.T.conj() @ w)) / D
+ else:
+ x = fact_solve(b)
+
+ return x
+
+ def accuracy(self, x: Array, b: Array) -> float:
+ r"""Compute solution relative residual.
+
+ Args:
+ x: Array :math:`\mathbf{x}` (solution).
+ b: Array :math:`\mathbf{b}` (right hand side of linear system).
+
+ Returns:
+ Relative residual of solution.
+ """
+ if b.ndim == 1:
+ D = self.D
+ else:
+ D = self.D[:, snp.newaxis]
+ assert isinstance(self.W, Array)
+ return rel_res(self.A.T.conj() @ (self.W[:, snp.newaxis] * self.A) @ x + D * x, b)
+
+
+class ConvATADSolver:
+ r"""Solver for sum of convolutions plus diagonal linear system.
+
+ Solve a linear system of the form
+
+ .. math::
+
+ (A^H A + D) \mb{x} = \mb{b}
+
+ where :math:`A` is a block-row operator with circulant blocks, i.e. it
+ can be written as
+
+ .. math::
+
+ A = \left( \begin{array}{cccc} A_1 & A_2 & \ldots & A_{K}
+ \end{array} \right) \;,
+
+ where all of the :math:`A_k` are circular convolution operators, and
+ :math:`D` is a circular convolution operator. This problem is most
+ easily solved in the DFT transform domain, where the circular
+ convolutions become diagonal operators. Denoting the frequency-domain
+ versions of variables with a circumflex (e.g. :math:`\hat{\mb{x}}` is
+ the frequency-domain version of :math:`\mb{x}`), the problem can
+ be written as
+
+ .. math::
+
+ (\hat{A}^H \hat{A} + \hat{D}) \hat{\mb{x}} = \hat{\mb{b}} \;,
+
+ where
+
+ .. math::
+
+ \hat{A} = \left( \begin{array}{cccc} \hat{A}_1 & \hat{A}_2 &
+ \ldots & \hat{A}_{K} \end{array} \right) \;,
+
+ and :math:`\hat{D}` and all the :math:`\hat{A}_k` are diagonal
+ operators.
+
+ This linear equation is computationally expensive to solve because
+ the left hand side includes the term :math:`\hat{A}^H \hat{A}`,
+ which corresponds to the outer product of :math:`\hat{A}^H`
+ and :math:`\hat{A}`. A computationally efficient solution is possible,
+ however, by exploiting the Woodbury matrix identity
+ :cite:`wohlberg-2014-efficient`
+
+ .. math::
+
+ (B + U C V)^{-1} = B^{-1} - B^{-1} U (C^{-1} + V B^{-1} U)^{-1}
+ V B^{-1} \;.
+
+ Setting
+
+ .. math::
+
+ B &= \hat{D} \\
+ U &= \hat{A}^H \\
+ C &= I \\
+ V &= \hat{A}
+
+ we have
+
+ .. math::
+
+ (\hat{D} + \hat{A}^H \hat{A})^{-1} = \hat{D}^{-1} - \hat{D}^{-1}
+ \hat{A}^H (I + \hat{A} \hat{D}^{-1} \hat{A}^H)^{-1} \hat{A}
+ \hat{D}^{-1}
+
+ which can be simplified to
+
+ .. math::
+
+ (\hat{D} + \hat{A}^H \hat{A})^{-1} = \hat{D}^{-1} (I - \hat{A}^H
+ \hat{E}^{-1} \hat{A} \hat{D}^{-1})
+
+ by defining :math:`\hat{E} = I + \hat{A} \hat{D}^{-1} \hat{A}^H`. The
+ right hand side is much cheaper to compute because the only matrix
+ inversions involve :math:`\hat{D}`, which is diagonal, and
+ :math:`\hat{E}`, which is a weighted inner product of
+ :math:`\hat{A}^H` and :math:`\hat{A}`.
+ """
+
+ def __init__(self, A: ComposedLinearOperator, D: CircularConvolve):
+ r"""
+ Args:
+ A: Operator :math:`A`.
+ D: Operator :math:`D`.
+ """
+ if not isinstance(A, ComposedLinearOperator):
+ raise TypeError(
+ f"Operator A is required to be a ComposedLinearOperator; got a {type(A)}."
+ )
+ if not isinstance(A.A, Sum) or not isinstance(A.B, CircularConvolve):
+ raise TypeError(
+ "Operator A is required to be a composition of Sum and CircularConvolve"
+ f"linear operators; got a composition of {type(A.A)} and {type(A.B)}."
+ )
+
+ self.A = A
+ self.D = D
+ self.sum_axis = A.A.kwargs["axis"]
+ if not isinstance(self.sum_axis, int):
+ raise ValueError(
+ "Sum component of operator A must sum over a single axis of its input."
+ )
+ self.fft_axes = A.B.x_fft_axes
+ self.real_result = is_real_dtype(D.input_dtype)
+
+ Ahat = A.B.h_dft
+ Dhat = D.h_dft
+ self.AHEinv = Ahat.conj() / (
+ 1.0 + snp.sum(Ahat * (Ahat.conj() / Dhat), axis=self.sum_axis, keepdims=True)
+ )
+
+ def solve(self, b: Array) -> Array:
+ r"""Solve the linear system.
+
+ Solve the linear system with right hand side :math:`\mb{b}`.
+
+ Args:
+ b: Array :math:`\mathbf{b}`.
+
+ Returns:
+ Solution to the linear system.
+ """
+ assert isinstance(self.A.B, CircularConvolve)
+
+ Ahat = self.A.B.h_dft
+ Dhat = self.D.h_dft
+ bhat = snp.fft.fftn(b, axes=self.fft_axes)
+ xhat = (
+ bhat - (self.AHEinv * (snp.sum(Ahat * bhat / Dhat, axis=self.sum_axis, keepdims=True)))
+ ) / Dhat
+ x = snp.fft.ifftn(xhat, axes=self.fft_axes)
+ if self.real_result:
+ x = x.real
+
+ return x
+
+ def accuracy(self, x: Array, b: Array) -> float:
+ r"""Compute solution relative residual.
+
+ Args:
+ x: Array :math:`\mathbf{x}` (solution).
+ b: Array :math:`\mathbf{b}` (right hand side of linear system).
+
+ Returns:
+ Relative residual of solution.
+ """
+ return rel_res(self.A.gram_op(x) + self.D(x), b)
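The new ATADSolver factorizes :math:`A^T W A + D` once and can then be applied to any number of right hand sides; when :math:`A` is wide and :math:`D` is diagonal, the Woodbury branch factorizes the smaller matrix instead. A small sketch of the intended call pattern, modeled on the tests added to test_solver.py below (the shapes and diagonal value are illustrative); ConvATADSolver follows the same solve/accuracy interface but acts on a Sum-of-CircularConvolve operator in the DFT domain.

import scico.numpy as snp
from scico import random, solver
from scico.metric import rel_res

# Wide A (more columns than rows) with diagonal D: the Woodbury path is used.
A, key = random.randn((5, 8), dtype=snp.float32)
D = 1e-1 * snp.ones((8,))  # diagonal of D; positive, so Cholesky is safe
x_true, key = random.randn((8,), key=key)
b = A.T @ (A @ x_true) + D * x_true  # right hand side of (A^T A + D) x = b

slv = solver.ATADSolver(A, D, cho_factor=True)
x = slv.solve(b)  # the factorization is reused for any further solves
assert rel_res(x, x_true) < 1e-4
assert slv.accuracy(x, b) < 1e-4  # relative residual reported by the solver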
diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py
index 282caa7d1..ce51196b6 100644
--- a/scico/test/linop/test_linop.py
+++ b/scico/test/linop/test_linop.py
@@ -374,6 +374,29 @@ def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator):
b = Dnew @ x
snp.testing.assert_allclose(a, b, rtol=1e-5)
+ @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64])
+ @pytest.mark.parametrize("input_shape1", input_shapes)
+ @pytest.mark.parametrize("input_shape2", input_shapes)
+ def test_matmul(self, input_shape1, input_shape2, diagonal_dtype):
+
+ diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key)
+ diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key)
+ x, key = randn(input_shape1, dtype=diagonal_dtype, key=key)
+
+ D1 = linop.Diagonal(diagonal=diagonal1)
+ D2 = linop.Diagonal(diagonal=diagonal2)
+
+ if input_shape1 != input_shape2:
+ with pytest.raises(ValueError):
+ D3 = D1 @ D2
+ else:
+ D3 = D1 @ D2
+ assert isinstance(D3, linop.Diagonal)
+ a = D3 @ x
+ D4 = linop.Diagonal(diagonal1 * diagonal2)
+ b = D4 @ x
+ snp.testing.assert_allclose(a, b, rtol=1e-5)
+
@pytest.mark.parametrize("operator", [op.add, op.sub])
def test_binary_op_mismatch(self, operator):
diagonal_dtype = np.float32
@@ -418,6 +441,40 @@ def test_scalar_left(self, operator):
np.testing.assert_allclose(scaled_D @ x, operator(D @ x, scalar), rtol=5e-5)
+ @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64])
+ def test_gram_op(self, diagonal_dtype):
+
+ input_shape = (7,)
+ diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)
+
+ D1 = linop.Diagonal(diagonal=diagonal)
+ D2 = D1.gram_op
+ D3 = D1.H @ D1
+ assert isinstance(D3, linop.Diagonal)
+ snp.testing.assert_allclose(D2.diagonal, D3.diagonal, rtol=1e-6)
+
+ @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64])
+ @pytest.mark.parametrize("ord", [None, "fro", "nuc", -np.inf, np.inf, 1, -1, 2, -2])
+ def test_norm(self, diagonal_dtype, ord):
+
+ input_shape = (5,)
+ diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)
+
+ D1 = linop.Diagonal(diagonal=diagonal)
+ D2 = snp.diag(diagonal)
+ n1 = D1.norm(ord=ord)
+ n2 = snp.linalg.norm(D2, ord=ord)
+ snp.testing.assert_allclose(n1, n2, rtol=1e-6)
+
+ def test_norm_except(self):
+
+ input_shape = (5,)
+ diagonal, key = randn(input_shape, dtype=np.float32, key=self.key)
+
+ D = linop.Diagonal(diagonal=diagonal)
+ with pytest.raises(ValueError):
+ n = D.norm(ord=3)
+
def test_adj_lazy():
dtype = np.float32
diff --git a/scico/test/optimize/test_admm.py b/scico/test/optimize/test_admm.py
index b100e3f40..b246c7416 100644
--- a/scico/test/optimize/test_admm.py
+++ b/scico/test/optimize/test_admm.py
@@ -5,7 +5,7 @@
import pytest
import scico.numpy as snp
-from scico import functional, linop, loss, metric, random
+from scico import functional, linop, loss, metric, operator, random
from scico.optimize import ADMM
from scico.optimize.admm import (
CircularConvolveSolver,
@@ -13,6 +13,7 @@
G0BlockCircularConvolveSolver,
GenericSubproblemSolver,
LinearSubproblemSolver,
+ MatrixSubproblemSolver,
)
@@ -72,6 +73,37 @@ def callback(obj):
with pytest.raises(ValueError):
admm_.solve()
+ @pytest.mark.parametrize(
+ "solver", [LinearSubproblemSolver, MatrixSubproblemSolver, CircularConvolveSolver]
+ )
+ def test_admm_aux(self, solver):
+ maxiter = 2
+ ρ = 1e-1
+ A = operator.Abs(self.y.shape)
+ f = loss.SquaredL2Loss(y=self.y, A=A)
+ g = functional.DnCNN()
+ C = linop.Identity(self.y.shape)
+
+ with pytest.raises(TypeError):
+ admm_ = ADMM(
+ f=f,
+ g_list=[g],
+ C_list=[C],
+ rho_list=[ρ],
+ maxiter=maxiter,
+ subproblem_solver=solver(),
+ )
+
+ with pytest.raises(TypeError):
+ admm_ = ADMM(
+ f=g,
+ g_list=[g],
+ C_list=[C],
+ rho_list=[ρ],
+ maxiter=maxiter,
+ subproblem_solver=solver(),
+ )
+
class TestReal:
def setup_method(self, method):
@@ -204,7 +236,7 @@ def setup_method(self, method):
self.grdA = lambda x: (𝛼 * Amx.T @ (W * Amx) + λ * Bmx.T @ Bmx) @ x
self.grdb = 𝛼 * Amx.T @ (W[:, 0] * y)
- def test_admm_quadratic(self):
+ def test_admm_quadratic_linear(self):
maxiter = 100
ρ = 1e0
A = linop.MatrixOperator(self.Amx)
@@ -225,6 +257,27 @@ def test_admm_quadratic(self):
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4
+ def test_admm_quadratic_matrix(self):
+ maxiter = 50
+ ρ = 1e0
+ A = linop.MatrixOperator(self.Amx)
+ f = loss.SquaredL2Loss(y=self.y, A=A, W=linop.Diagonal(self.W[:, 0]), scale=self.𝛼 / 2.0)
+ g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
+ C_list = [linop.MatrixOperator(self.Bmx)]
+ rho_list = [ρ]
+ admm_ = ADMM(
+ f=f,
+ g_list=g_list,
+ C_list=C_list,
+ rho_list=rho_list,
+ maxiter=maxiter,
+ itstat_options={"display": False},
+ x0=A.adj(self.y),
+ subproblem_solver=MatrixSubproblemSolver(),
+ )
+ x = admm_.solve()
+ assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5
+
class TestComplex:
def setup_method(self, method):
@@ -234,12 +287,12 @@ def setup_method(self, method):
# Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
Amx, key = random.randn((MA, N), dtype=np.complex64, key=None)
Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)
- y = np.random.randn(MA)
+ y, key = random.randn((MA,), dtype=np.complex64, key=key)
𝛼 = 1.0 / 3.0
λ = 1e0
self.Amx = Amx
self.Bmx = Bmx
- self.y = jax.device_put(y)
+ self.y = y
self.𝛼 = 𝛼
self.λ = λ
# Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = A^T y
@@ -267,7 +320,7 @@ def test_admm_generic(self):
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-3
- def test_admm_quadratic(self):
+ def test_admm_quadratic_linear(self):
maxiter = 50
ρ = 1e0
A = linop.MatrixOperator(self.Amx)
@@ -290,6 +343,27 @@ def test_admm_quadratic(self):
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4
+ def test_admm_quadratic_matrix(self):
+ maxiter = 50
+ ρ = 1e0
+ A = linop.MatrixOperator(self.Amx)
+ f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
+ g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
+ C_list = [linop.MatrixOperator(self.Bmx)]
+ rho_list = [ρ]
+ admm_ = ADMM(
+ f=f,
+ g_list=g_list,
+ C_list=C_list,
+ rho_list=rho_list,
+ maxiter=maxiter,
+ itstat_options={"display": False},
+ x0=A.adj(self.y),
+ subproblem_solver=MatrixSubproblemSolver(),
+ )
+ x = admm_.solve()
+ assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5
+
class TestCircularConvolveSolve:
def setup_method(self, method):
@@ -369,7 +443,7 @@ def test_fblock_init(self):
itstat_options={"display": False},
subproblem_solver=FBlockCircularConvolveSolver(),
)
- with pytest.raises(ValueError):
+ with pytest.raises(TypeError):
slvr = ADMM(
f=loss.PoissonLoss(y=self.y),
g_list=self.g_list,
@@ -378,7 +452,7 @@ def test_fblock_init(self):
itstat_options={"display": False},
subproblem_solver=FBlockCircularConvolveSolver(),
)
- with pytest.raises(ValueError):
+ with pytest.raises(TypeError):
slvr = ADMM(
f=loss.SquaredL2Loss(y=self.y, A=self.A.A),
g_list=self.g_list,
@@ -398,7 +472,7 @@ def test_g0block_init(self):
itstat_options={"display": False},
subproblem_solver=G0BlockCircularConvolveSolver(),
)
- with pytest.raises(ValueError):
+ with pytest.raises(TypeError):
slvr = ADMM(
f=functional.ZeroFunctional(),
g_list=[loss.PoissonLoss(y=self.y)],
@@ -407,7 +481,7 @@ def test_g0block_init(self):
itstat_options={"display": False},
subproblem_solver=G0BlockCircularConvolveSolver(),
)
- with pytest.raises(ValueError):
+ with pytest.raises(TypeError):
slvr = ADMM(
f=functional.ZeroFunctional(),
g_list=[loss.SquaredL2Loss(y=self.y)] + self.g_list,
diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py
index abe3745a2..d5b179b62 100644
--- a/scico/test/test_solver.py
+++ b/scico/test/test_solver.py
@@ -6,7 +6,7 @@
import pytest
import scico.numpy as snp
-from scico import linop, random, solver
+from scico import linop, metric, random, solver
class TestSet:
@@ -294,3 +294,78 @@ def test_golden():
f = lambda x, c: (x - c) ** 2
x = solver.golden(f, -snp.abs(c) - 1, snp.abs(c) + 1, args=(c,), xtol=1e-5)
assert snp.max(snp.abs(x - c)) <= 1e-5
+
+
+@pytest.mark.parametrize("cho_factor", [True, False])
+@pytest.mark.parametrize("wide", [True, False])
+@pytest.mark.parametrize("weighted", [True, False])
+@pytest.mark.parametrize("alpha", [1e-1, 1e1])
+def test_solve_atai(cho_factor, wide, weighted, alpha):
+ A, key = random.randn((5, 8), dtype=snp.float32)
+ if wide:
+ x0, key = random.randn((8,), key=key)
+ else:
+ A = A.T
+ x0, key = random.randn((5,), key=key)
+
+ if weighted:
+ W, key = random.randn((A.shape[0],), key=key)
+ W = snp.abs(W)
+ Wa = W[:, snp.newaxis]
+ else:
+ W = None
+ Wa = snp.array([1.0])[:, snp.newaxis]
+
+ D = alpha * snp.ones((A.shape[1],))
+ ATAD = A.T @ (Wa * A) + alpha * snp.identity(A.shape[1])
+ b = ATAD @ x0
+ slv = solver.ATADSolver(A, D, W=W, cho_factor=cho_factor)
+ x1 = slv.solve(b)
+ assert metric.rel_res(x0, x1) < 5e-5
+
+
+@pytest.mark.parametrize("cho_factor", [True, False])
+@pytest.mark.parametrize("wide", [True, False])
+@pytest.mark.parametrize("alpha", [1e-1, 1e1])
+def test_solve_aati(cho_factor, wide, alpha):
+ A, key = random.randn((5, 8), dtype=snp.float32)
+ if wide:
+ x0, key = random.randn((5,), key=key)
+ else:
+ A = A.T
+ x0, key = random.randn((8,), key=key)
+
+ D = alpha * snp.ones((A.shape[0],))
+ AATD = A @ A.T + alpha * snp.identity(A.shape[0])
+ b = AATD @ x0
+ slv = solver.ATADSolver(A.T, D)
+ x1 = slv.solve(b)
+ assert metric.rel_res(x0, x1) < 5e-5
+
+
+@pytest.mark.parametrize("cho_factor", [True, False])
+@pytest.mark.parametrize("wide", [True, False])
+@pytest.mark.parametrize("vector", [True, False])
+def test_solve_atad(cho_factor, wide, vector):
+ A, key = random.randn((5, 8), dtype=snp.float32)
+ if wide:
+ D, key = random.randn((8,), key=key)
+ if vector:
+ x0, key = random.randn((8,), key=key)
+ else:
+ x0, key = random.randn((8, 3), key=key)
+ else:
+ A = A.T
+ D, key = random.randn((5,), key=key)
+ if vector:
+ x0, key = random.randn((5,), key=key)
+ else:
+ x0, key = random.randn((5, 3), key=key)
+
+ D = snp.abs(D) # only required for Cholesky, but improves accuracy for LU
+ ATAD = A.T @ A + snp.diag(D)
+ b = ATAD @ x0
+ slv = solver.ATADSolver(A, D, cho_factor=cho_factor)
+ x1 = slv.solve(b)
+ assert metric.rel_res(x0, x1) < 5e-5
+ assert slv.accuracy(x1, b) < 5e-5
diff --git a/scico/test/test_util.py b/scico/test/test_util.py
index 5b58c0bf5..5bbdf002d 100644
--- a/scico/test/test_util.py
+++ b/scico/test/test_util.py
@@ -8,7 +8,32 @@
import pytest
import scico.numpy as snp
-from scico.util import ContextTimer, Timer, check_for_tracer, partial, url_get
+from scico.util import (
+ ContextTimer,
+ Timer,
+ check_for_tracer,
+ partial,
+ rgetattr,
+ rsetattr,
+ url_get,
+)
+
+
+def test_rattr():
+ class A:
+ class B:
+ c = 0
+
+ b = B()
+
+ a = A()
+ rsetattr(a, "b.c", 1)
+ assert rgetattr(a, "b.c") == 1
+
+ assert rgetattr(a, "c.d", 10) == 10
+
+ with pytest.raises(AttributeError):
+ assert rgetattr(a, "c.d")
def test_partial_pos():
diff --git a/scico/util.py b/scico/util.py
index 73982fd0b..d57c8efc6 100644
--- a/scico/util.py
+++ b/scico/util.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright (C) 2020-2022 by SCICO Developers
+# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -14,7 +14,7 @@
import socket
import urllib.error as urlerror
import urllib.request as urlrequest
-from functools import wraps
+from functools import reduce, wraps
from timeit import default_timer as timer
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
@@ -23,6 +23,44 @@
from jax.interpreters.partial_eval import DynamicJaxprTracer
+def rgetattr(obj: object, name: str, default: Optional[Any] = None) -> Any:
+ """Recursive version of :func:`getattr`.
+
+ Args:
+ obj: Object with the attribute to be accessed.
+ name: Path to attribute, with components delimited by a "."
+ character.
+ default: Default value to be returned if the attribute does not
+ exist.
+
+ Returns:
+ Attribute value, or default if the attribute does not exist.
+ """
+
+ try:
+ return reduce(getattr, name.split("."), obj)
+ except AttributeError as e:
+ if default is not None:
+ return default
+ else:
+ raise e
+
+
+def rsetattr(obj: object, name: str, value: Any):
+ """Recursive version of :func:`setattr`.
+
+ Args:
+ obj: Object with the attribute to be set.
+ name: Path to attribute, with components delimited by a "."
+ character.
+ value: Value to which the attribute is to be set.
+ """
+
+ # See goo.gl/BVJ7MN
+ path = name.split(".")
+ setattr(reduce(getattr, path[:-1], obj), path[-1], value)
+
+
def partial(func: Callable, indices: Sequence, *fixargs: Any, **fixkwargs: Any) -> Callable:
"""Flexible partial function creation.