diff --git a/.github/codecov.yml b/.github/codecov.yml index 9de3e3f04..2134d07fc 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -1,3 +1,11 @@ coverage: + precision: 2 + round: nearest + range: "80...100" + status: + project: + default: + target: auto + threshold: 0.05% patch: false diff --git a/examples/scripts/sparsecode_admm.py b/examples/scripts/sparsecode_admm.py index ed500f60e..32f3eafde 100644 --- a/examples/scripts/sparsecode_admm.py +++ b/examples/scripts/sparsecode_admm.py @@ -24,7 +24,7 @@ import jax from scico import functional, linop, loss, plot -from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.optimize.admm import ADMM, MatrixSubproblemSolver from scico.util import device_info """ @@ -67,7 +67,7 @@ rho_list=rho_list, x0=A.adj(y), maxiter=maxiter, - subproblem_solver=LinearSubproblemSolver(), + subproblem_solver=MatrixSubproblemSolver(), itstat_options={"display": True, "period": 10}, ) diff --git a/scico/linop/_diag.py b/scico/linop/_diag.py index 30cd21949..0e9251f57 100644 --- a/scico/linop/_diag.py +++ b/scico/linop/_diag.py @@ -42,9 +42,9 @@ def __init__( r""" Args: diagonal: Diagonal elements of this :class:`LinearOperator`. - input_shape: Shape of input array. By default, equal to + input_shape: Shape of input array. By default, equal to `diagonal.shape`, but may also be set to a shape that is - broadcast-compatiable with `diagonal.shape`. + broadcast-compatible with `diagonal.shape`. input_dtype: `dtype` of input argument. The default, ``None``, means `diagonal.dtype`. """ @@ -64,7 +64,7 @@ def __init__( elif isinstance(diagonal, BlockArray): raise ValueError("Parameter diagonal was a BlockArray but input_shape was not nested.") else: - raise ValueError("Parameter diagonal was a not BlockArray but input_shape was nested.") + raise ValueError("Parameter diagonal was not a BlockArray but input_shape was nested.") super().__init__( input_shape=input_shape, @@ -77,6 +77,29 @@ def __init__( def _eval(self, x): return x * self.diagonal + @property + def T(self) -> Diagonal: + """Transpose of this :class:`Diagonal`.""" + return self + + def conj(self) -> Diagonal: + """Complex conjugate of this :class:`Diagonal`.""" + return Diagonal(diagonal=self.diagonal.conj()) + + @property + def H(self) -> Diagonal: + """Hermitian transpose of this :class:`Diagonal`.""" + return self.conj() + + @property + def gram_op(self) -> Diagonal: + """Gram operator of this :class:`Diagonal`. + + Return a new :class:`Diagonal` :code:`G` such that + :code:`G(x) = A.adj(A(x)))`. + """ + return Diagonal(diagonal=self.diagonal.conj() * self.diagonal) + @partial(_wrap_add_sub, op=operator.add) def __add__(self, other): if self.diagonal.shape == other.diagonal.shape: @@ -101,6 +124,41 @@ def __rmul__(self, scalar): def __truediv__(self, scalar): return Diagonal(diagonal=self.diagonal / scalar) + def __matmul__(self, other): + # self @ other + if isinstance(other, Diagonal): + if self.shape == other.shape: + return Diagonal(diagonal=self.diagonal * other.diagonal) + + raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.") + + else: + return self(other) + + def norm(self, ord=None): # pylint: disable=W0622 + """Compute the matrix norm of the diagonal operator. + + Valid values of `ord` and the corresponding norm definition + are those listed under "norm for matrices" in the + :func:`scico.numpy.linalg.norm` documentation. 
+ """ + ordfunc = { + "fro": lambda x: snp.linalg.norm(x), + "nuc": lambda x: snp.sum(snp.abs(x)), + -snp.inf: lambda x: snp.abs(x).min(), + snp.inf: lambda x: snp.abs(x).max(), + } + mord = ord + if mord is None: + mord = "fro" + elif mord in (-1, -2): + mord = -snp.inf + elif mord in (1, 2): + mord = snp.inf + if mord not in ordfunc: + raise ValueError(f"Invalid value {ord} for parameter ord.") + return ordfunc[mord](self.diagonal) + class Identity(Diagonal): """Identity operator.""" diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index e2d35e3c7..0f6f21daa 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -172,7 +172,7 @@ def __mul__(self, other): raise TypeError(f"Operation __mul__ not defined between {type(self)} and {type(other)}.") def __rmul__(self, other): - # Multiplication is commutative + # multiplication is commutative return self * other def __truediv__(self, other): @@ -255,6 +255,6 @@ def gram_op(self): def norm(self, ord=None, axis=None, keepdims=False): # pylint: disable=W0622 """Compute the norm of the dense matrix `self.A`. - Call :func:`scico.numpy.norm` on the dense matrix `self.A`. + Call :func:`scico.numpy.linalg.norm` on the dense matrix `self.A`. """ return snp.linalg.norm(self.A, ord=ord, axis=axis, keepdims=keepdims) diff --git a/scico/optimize/_admm.py b/scico/optimize/_admm.py index ec4fa8edf..55862f0cb 100644 --- a/scico/optimize/_admm.py +++ b/scico/optimize/_admm.py @@ -25,6 +25,7 @@ G0BlockCircularConvolveSolver, GenericSubproblemSolver, LinearSubproblemSolver, + MatrixSubproblemSolver, SubproblemSolver, ) from ._common import Optimizer @@ -189,7 +190,7 @@ def _itstat_extra_fields(self): ) elif ( type(self.subproblem_solver) - in [FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver] + in [MatrixSubproblemSolver, FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver] and self.subproblem_solver.check_solve ): itstat_fields.update({"Slv Res": "%9.3e"}) diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py index 5d670d7e6..7a0ceb0b2 100644 --- a/scico/optimize/_admmaux.py +++ b/scico/optimize/_admmaux.py @@ -23,14 +23,15 @@ from scico.linop import ( CircularConvolve, ComposedLinearOperator, + Diagonal, Identity, LinearOperator, - Sum, + MatrixOperator, ) from scico.loss import SquaredL2Loss -from scico.metric import rel_res from scico.numpy import Array, BlockArray from scico.numpy.util import ensure_on_device, is_real_dtype +from scico.solver import ATADSolver, ConvATADSolver from scico.solver import cg as scico_cg from scico.solver import minimize @@ -126,8 +127,8 @@ class LinearSubproblemSolver(SubproblemSolver): for the case where :code:`f` is an :math:`\ell_2` or weighted :math:`\ell_2` norm, and :code:`f.A` is a linear operator, so that the subproblem involves solving a linear equation. This requires that - `f.functional` be an instance of :class:`.SquaredL2Loss` and for - the forward operator :code:`f.A` to be an instance of + :code:`f.functional` be an instance of :class:`.SquaredL2Loss` and + for the forward operator :code:`f.A` to be an instance of :class:`.LinearOperator`. The :math:`\mb{x}`-update step is @@ -139,7 +140,7 @@ class LinearSubproblemSolver(SubproblemSolver): \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;, where :math:`W` a weighting :class:`.Diagonal` operator - or an :class:`.Identity` operator (i.e. no weighting). + or an :class:`.Identity` operator (i.e., no weighting). This update step reduces to the solution of the linear system .. 
math:: @@ -200,19 +201,19 @@ def __init__(self, cg_kwargs: Optional[dict[str, Any]] = None, cg_function: str def internal_init(self, admm: soa.ADMM): if admm.f is not None: if not isinstance(admm.f, SquaredL2Loss): - raise ValueError( + raise TypeError( "LinearSubproblemSolver requires f to be a scico.loss.SquaredL2Loss; " f"got {type(admm.f)}." ) if not isinstance(admm.f.A, LinearOperator): - raise ValueError( + raise TypeError( "LinearSubproblemSolver requires f.A to be a scico.linop.LinearOperator; " f"got {type(admm.f.A)}." ) super().internal_init(admm) - # Set lhs_op = \sum_i rho_i * Ci.H @ CircularConvolve + # Set lhs_op = \sum_i rho_i * Ci.H @ Ci # Use reduce as the initialization of this sum is messy otherwise lhs_op = reduce( lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)] @@ -266,6 +267,110 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: return x +class MatrixSubproblemSolver(LinearSubproblemSolver): + r"""Solver for quadratic functionals involving matrix operators. + + Solver for the case in which :math:`f` is a quadratic function of + :math:`\mb{x}`, and :math:`A` and all the :math:`C_i` are diagonal + or matrix operators. It is a specialization of + :class:`.LinearSubproblemSolver`. + + As for :class:`.LinearSubproblemSolver`, the :math:`\mb{x}`-update + step is + + .. math:: + + \mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2} + \norm{\mb{y} - A \mb{x}}_W^2 + \sum_i \frac{\rho_i}{2} + \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;, + + where :math:`W` is a weighting :class:`.Diagonal` operator + or an :class:`.Identity` operator (i.e., no weighting). + This update step reduces to the solution of the linear system + + .. math:: + + \left(A^H W A + \sum_{i=1}^N \rho_i C_i^H C_i \right) + \mb{x}^{(k+1)} = \; + A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i - + \mb{u}^{(k)}_i) \;, + + which is solved by factorization of the left hand side of the + equation, using :class:`.ATADSolver`. + + + Attributes: + admm (:class:`.ADMM`): ADMM solver object to which the solver is + attached. + solve_kwargs (dict): Dictionary of arguments for solver + :class:`.ATADSolver` initialization. + """ + + def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, Any]] = None): + """Initialize a :class:`MatrixSubproblemSolver` object. + + Args: + check_solve: If ``True``, compute solver accuracy after each + solve. + solve_kwargs: Dictionary of arguments for solver + :class:`.ATADSolver` initialization. + """ + self.check_solve = check_solve + default_solve_kwargs = {"cho_factor": False} + if solve_kwargs: + default_solve_kwargs.update(solve_kwargs) + self.solve_kwargs = default_solve_kwargs + + def internal_init(self, admm: soa.ADMM): + if admm.f is not None: + if not isinstance(admm.f, SquaredL2Loss): + raise TypeError( + "MatrixSubproblemSolver requires f to be a scico.loss.SquaredL2Loss; " + f"got {type(admm.f)}." + ) + if not isinstance(admm.f.A, (Diagonal, MatrixOperator)): + raise TypeError( + "MatrixSubproblemSolver requires f.A to be a Diagonal or MatrixOperator; " + f"got {type(admm.f.A)}." + ) + for i, Ci in enumerate(admm.C_list): + if not isinstance(Ci, (Diagonal, MatrixOperator)): + raise TypeError( + "MatrixSubproblemSolver requires C[{i}] to be a Diagonal or MatrixOperator; " + f"got {type(Ci)}." 
+ ) + + super().internal_init(admm) + + if admm.f is None: + A = snp.zeros(admm.C_list[0].input_shape[0], dtype=admm.C_list[0].input_dtype) + W = None + else: + A = admm.f.A + W = 2.0 * self.admm.f.scale * admm.f.W # type: ignore + + Csum = reduce( + lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)] + ) + self.solver = ATADSolver(A, Csum, W, **self.solve_kwargs) + + def solve(self, x0: Array) -> Array: + """Solve the ADMM step. + + Args: + x0: Initial value (ignored). + + Returns: + Computed solution. + """ + rhs = self.compute_rhs() + x = self.solver.solve(rhs) + if self.check_solve: + self.accuracy = self.solver.accuracy(x, rhs) + + return x + + class CircularConvolveSolver(LinearSubproblemSolver): r"""Solver for linear operators diagonalized in the DFT domain. @@ -293,12 +398,12 @@ def __init__(self): def internal_init(self, admm: soa.ADMM): if admm.f is not None: if not isinstance(admm.f, SquaredL2Loss): - raise ValueError( + raise TypeError( "CircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; " f"got {type(admm.f)}." ) if not isinstance(admm.f.A, (CircularConvolve, Identity)): - raise ValueError( + raise TypeError( "CircularConvolveSolver requires f.A to be a scico.linop.CircularConvolve " f"or scico.linop.Identity; got {type(admm.f.A)}." ) @@ -454,35 +559,19 @@ def internal_init(self, admm: soa.ADMM): raise ValueError("FBlockCircularConvolveSolver does not allow f to be None.") else: if not isinstance(admm.f, SquaredL2Loss): - raise ValueError( + raise TypeError( "FBlockCircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; " f"got {type(admm.f)}." ) if not isinstance(admm.f.A, ComposedLinearOperator): - raise ValueError( + raise TypeError( "FBlockCircularConvolveSolver requires f.A to be a composition of Sum " f"and CircularConvolve linear operators; got {type(admm.f.A)}." ) - if not isinstance(admm.f.A.A, Sum) or not isinstance(admm.f.A.B, CircularConvolve): - raise ValueError( - "FBlockCircularConvolveSolver requires f.A to be a composition of Sum " - "and CircularConvolve linear operators; got a composition of " - f"{type(admm.f.A.A)} and {type(admm.f.A.B)}." - ) - self.sum_axis = admm.f.A.A.kwargs["axis"] - if not isinstance(self.sum_axis, int): - raise ValueError( - "FBlockCircularConvolveSolver requires the Sum operator component " - "of f.A to sum over a single axis of its input." - ) - super().internal_init(admm) assert isinstance(self.admm.f, SquaredL2Loss) assert isinstance(self.admm.f.A, ComposedLinearOperator) - assert isinstance(self.admm.f.A.B, CircularConvolve) - - self.real_result = is_real_dtype(admm.C_list[0].input_dtype) # All of the C operators are assumed to be linear and shift invariant # but this is not checked. @@ -490,12 +579,8 @@ def internal_init(self, admm: soa.ADMM): rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims) for rho, C in zip(admm.rho_list, admm.C_list) ] - self.D = reduce(lambda a, b: a + b, c_gram_list) / (2.0 * self.admm.f.scale) - A = self.admm.f.A.B - self.AHEinv = A.h_dft.conj() / ( - 1.0 - + snp.sum(A.h_dft * (A.h_dft.conj() / self.D.h_dft), axis=self.sum_axis, keepdims=True) - ) + D = reduce(lambda a, b: a + b, c_gram_list) / (2.0 * self.admm.f.scale) + self.solver = ConvATADSolver(self.admm.f.A, D) def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: """Solve the ADMM step. @@ -507,27 +592,11 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: Computed solution. 
""" assert isinstance(self.admm.f, SquaredL2Loss) - assert isinstance(self.admm.f.A, ComposedLinearOperator) - assert isinstance(self.admm.f.A.B, CircularConvolve) rhs = self.compute_rhs() / (2.0 * self.admm.f.scale) - fft_axes = self.admm.f.A.B.x_fft_axes - rhs_dft = snp.fft.fftn(rhs, axes=fft_axes) - A = self.admm.f.A.B - x_dft = ( - rhs_dft - - ( - self.AHEinv - * (snp.sum(A.h_dft * rhs_dft / self.D.h_dft, axis=self.sum_axis, keepdims=True)) - ) - ) / self.D.h_dft - x = snp.fft.ifftn(x_dft, axes=fft_axes) - if self.real_result: - x = x.real - + x = self.solver.solve(rhs) if self.check_solve: - lhs = self.admm.f.A.gram_op(x) + self.D(x) - self.accuracy = rel_res(lhs, rhs) + self.accuracy = self.solver.accuracy(x, rhs) return x @@ -537,7 +606,7 @@ class G0BlockCircularConvolveSolver(SubproblemSolver): domain. Specialization of :class:`.LinearSubproblemSolver` for the case - where :math:`f = 0` (i.e. :code:`f` is a :class:`.ZeroFunctional`), + where :math:`f = 0` (i.e, :code:`f` is a :class:`.ZeroFunctional`), :math:`g_1` is an instance of :class:`.SquaredL2Loss`, :math:`C_1` is a composition of a :class:`.Sum` operator an a :class:`.CircularConvolve` operator. The former must sum over the @@ -661,37 +730,20 @@ def internal_init(self, admm: soa.ADMM): "G0BlockCircularConvolveSolver requires f to be None or a ZeroFunctional" ) if not isinstance(admm.g_list[0], SquaredL2Loss): - raise ValueError( + raise TypeError( "G0BlockCircularConvolveSolver requires g_1 to be a scico.loss.SquaredL2Loss; " f"got {type(admm.g_list[0])}." ) if not isinstance(admm.C_list[0], ComposedLinearOperator): - raise ValueError( + raise TypeError( "G0BlockCircularConvolveSolver requires C_1 to be a composition of Sum " f"and CircularConvolve linear operators; got {type(admm.C_list[0])}." ) - if not isinstance(admm.C_list[0].A, Sum) or not isinstance( - admm.C_list[0].B, CircularConvolve - ): - raise ValueError( - "G0BlockCircularConvolveSolver requires C_1 to be a composition of Sum " - "and CircularConvolve linear operators; got a composition of " - f"{type(admm.C_list[0].A)} and {type(admm.C_list[0].B)}." - ) - self.sum_axis = admm.C_list[0].A.kwargs["axis"] - if not isinstance(self.sum_axis, int): - raise ValueError( - "G0BlockCircularConvolveSolver requires the Sum operator component " - "of C_1 to sum over a single axis of its input." - ) super().internal_init(admm) assert isinstance(self.admm.g_list[0], SquaredL2Loss) assert isinstance(self.admm.C_list[0], ComposedLinearOperator) - assert isinstance(self.admm.C_list[0].B, CircularConvolve) - - self.real_result = is_real_dtype(admm.C_list[0].input_dtype) # All of the C operators are assumed to be linear and shift invariant # but this is not checked. @@ -699,14 +751,10 @@ def internal_init(self, admm: soa.ADMM): rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims) for rho, C in zip(admm.rho_list[1:], admm.C_list[1:]) ] - self.D = reduce(lambda a, b: a + b, c_gram_list) / ( + D = reduce(lambda a, b: a + b, c_gram_list) / ( 2.0 * self.admm.g_list[0].scale * admm.rho_list[0] ) - A = self.admm.C_list[0].B - self.AHEinv = A.h_dft.conj() / ( - 1.0 - + snp.sum(A.h_dft * (A.h_dft.conj() / self.D.h_dft), axis=self.sum_axis, keepdims=True) - ) + self.solver = ConvATADSolver(self.admm.C_list[0], D) def compute_rhs(self) -> Union[Array, BlockArray]: r"""Compute the right hand side of the linear equation to be solved. @@ -720,7 +768,7 @@ def compute_rhs(self) -> Union[Array, BlockArray]: ( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i) \;. Returns: - Computed solution. 
+ Right hand side of the linear equation. """ assert isinstance(self.admm.g_list[0], SquaredL2Loss) @@ -746,26 +794,10 @@ def solve(self, x0: Union[Array, BlockArray]) -> Union[Array, BlockArray]: Computed solution. """ assert isinstance(self.admm.g_list[0], SquaredL2Loss) - assert isinstance(self.admm.C_list[0], ComposedLinearOperator) - assert isinstance(self.admm.C_list[0].B, CircularConvolve) rhs = self.compute_rhs() / (2.0 * self.admm.g_list[0].scale * self.admm.rho_list[0]) - fft_axes = self.admm.C_list[0].B.x_fft_axes - rhs_dft = snp.fft.fftn(rhs, axes=fft_axes) - A = self.admm.C_list[0].B - x_dft = ( - rhs_dft - - ( - self.AHEinv - * (snp.sum(A.h_dft * rhs_dft / self.D.h_dft, axis=self.sum_axis, keepdims=True)) - ) - ) / self.D.h_dft - x = snp.fft.ifftn(x_dft, axes=fft_axes) - if self.real_result: - x = x.real - + x = self.solver.solve(rhs) if self.check_solve: - lhs = self.admm.C_list[0].gram_op(x) + self.D(x) - self.accuracy = rel_res(lhs, rhs) + self.accuracy = self.solver.accuracy(x, rhs) return x diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index ee9a126ed..7f64cc2d6 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -14,6 +14,7 @@ SubproblemSolver, GenericSubproblemSolver, LinearSubproblemSolver, + MatrixSubproblemSolver, CircularConvolveSolver, FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver, @@ -24,6 +25,7 @@ "SubproblemSolver", "GenericSubproblemSolver", "LinearSubproblemSolver", + "MatrixSubproblemSolver", "CircularConvolveSolver", "FBlockCircularConvolveSolver", "G0BlockCircularConvolveSolver", diff --git a/scico/solver.py b/scico/solver.py index d69c47247..5f0994246 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -5,24 +5,14 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""SciPy optimization algorithms. - -.. raw:: html - - - -This module provides scico interface wrappers for functions +"""Solver and optimization algorithms. + +This module provides a number of functions for solving linear systems and +optimization problems, some of which are used as subproblem solvers +within the iterations of the proximal algorithms in the +:mod:`scico.optimize` subpackage. + +This module also provides scico interface wrappers for functions from :mod:`scipy.optimize` since jax directly implements only a very limited subset of these functions (there is limited, experimental support for `L-BFGS-B `_), but only CG @@ -39,8 +29,8 @@ The wrapper also JIT compiles the function and gradient evaluations. -The functions provided in this module have a number of advantages and -disadvantages with respect to those in :mod:`jax.scipy.optimize`: +These wrapper functions have a number of advantages and disadvantages +with respect to those in :mod:`jax.scipy.optimize`: - This module provides many more algorithms than :mod:`jax.scipy.optimize`. @@ -52,7 +42,7 @@ - The solvers in this module can't be JIT compiled, and gradients cannot be taken through them. -In the future, this module may be replaced with a dependency on +In the future, these wrapper functions may be replaced with a dependency on `JAXopt `__. 
""" @@ -64,10 +54,20 @@ import jax import jax.experimental.host_callback as hcb +import jax.scipy.linalg as jsl -import scico.linop import scico.numpy as snp +from scico.linop import ( + CircularConvolve, + ComposedLinearOperator, + Diagonal, + LinearOperator, + MatrixOperator, + Sum, +) +from scico.metric import rel_res from scico.numpy import Array, BlockArray +from scico.numpy.util import is_real_dtype from scico.typing import BlockShape, DType, Shape from scipy import optimize as spopt @@ -336,7 +336,7 @@ def cg( - **info**: Dictionary containing diagnostic information. """ if x0 is None: - if isinstance(A, scico.linop.LinearOperator): + if isinstance(A, LinearOperator): x0 = snp.zeros(A.input_shape, b.dtype) else: raise ValueError("Parameter x0 must be specified if A is not a LinearOperator") @@ -418,11 +418,11 @@ def lstsq( - **x** : Solution array. - **info**: Dictionary containing diagnostic information. """ - if isinstance(A, scico.linop.LinearOperator): + if isinstance(A, LinearOperator): Aop = A else: assert x0 is not None - Aop = scico.linop.LinearOperator( + Aop = LinearOperator( input_shape=x0.shape, output_shape=b.shape, eval_fn=A, @@ -577,3 +577,371 @@ def golden( else: r = x return r + + +class ATADSolver: + r"""Solver for linear system involving a symmetric product plus a diagonal. + + Solve a linear system of the form + + .. math:: + + (A^T W A + D) \mb{x} = \mb{b} + + or + + .. math:: + + (A^T W A + D) X = B \;, + + where :math:`A \in \mbb{R}^{M \times N}`, + :math:`W \in \mbb{R}^{M \times M}` and + :math:`D \in \mbb{R}^{N \times N}`. The solution is computed by + factorization of matrix :math:`A^T W A + D` and solution via Gaussian + elimination. If :math:`D` is diagonal and :math:`N < M` (i.e. + :math:`A W A^T` is smaller than :math:`A^T W A`), then + :math:`A W A^T + D` is factorized and the original problem is solved + via the Woodbury matrix identity + + .. math:: + + (E + U C V)^{-1} = E^{-1} - E^{-1} U (C^{-1} + V E^{-1} U)^{-1} + V E^{-1} \;. + + Setting + + .. math:: + + E &= D \\ + U &= A^T \\ + C &= W \\ + V &= A + + we have + + .. math:: + + (D + A^T W A)^{-1} = D^{-1} - D^{-1} A^T (W^{-1} + A D^{-1} A^T)^{-1} A + D^{-1} + + which can be simplified to + + .. math:: + + (D + A^T W A)^{-1} = D^{-1} (I - A^T G^{-1} A D^{-1}) + + by defining :math:`G = W^{-1} + A D^{-1} A^T`. We therefore have that + + .. math:: + + \mb{x} = (D + A^T W A)^{-1} \mb{b} = D^{-1} (I - A^T G^{-1} A + D^{-1}) \mb{b} \;. + + If we have a Cholesky factorization of :math:`G`, e.g. + :math:`G = L L^T`, we can define + + .. math:: + + \mb{w} = G^{-1} A D^{-1} \mb{b} + + so that + + .. math:: + + G \mb{w} &= A D^{-1} \mb{b} \\ + L L^T \mb{w} &= A D^{-1} \mb{b} \;. + + The Cholesky factorization can be exploited by solving for + :math:`\mb{z}` in + + .. math:: + + L \mb{z} = A D^{-1} \mb{b} + + and then for :math:`\mb{w}` in + + .. math:: + + L^T \mb{w} = \mb{z} \;, + + so that + + .. math:: + + \mb{x} = D^{-1} \mb{b} - D^{-1} A^T \mb{w} \;. + + (Functions :func:`~jax.scipy.linalg.cho_solve` and + :func:`~jax.scipy.linalg.lu_solve` allow direct solution for + :math:`\mb{w}` without the two-step procedure described here.) A + Cholesky factorization should only be used when :math:`G` is + positive-definite (e.g. :math:`D` is diagonal and positive); if not, + an LU factorization should be used. + + Complex-valued problems are also supported, in which case the + transpose :math:`\cdot^T` in the equations above should be taken to + represent the conjugate transpose. 
+ + To solve problems directly involving a matrix of the form + :math:`A W A^T + D`, initialize with :code:`A.T` (or + :code:`A.T.conj()` for complex problems) instead of :code:`A`. + """ + + def __init__( + self, + A: Union[MatrixOperator, Array], + D: Union[MatrixOperator, Diagonal, Array], + W: Optional[Union[Diagonal, Array]] = None, + cho_factor: bool = False, + lower: bool = False, + check_finite: bool = True, + ): + r""" + Args: + A: Matrix :math:`A`. + D: Matrix :math:`D`. + W: Matrix :math:`W`. + cho_factor: Flag indicating whether to use Cholesky + (``True``) or LU (``False``) factorization. + lower: Flag indicating whether lower (``True``) or upper + (``False``) triangular factorization should be computed. + Only relevant to Cholesky factorization. + check_finite: Flag indicating whether the input array should + be checked for ``Inf`` and ``NaN`` values. + """ + if isinstance(A, MatrixOperator): + A = A.to_array() + if isinstance(D, MatrixOperator): + D = D.to_array() + elif isinstance(D, Diagonal): + D = D.diagonal + if W is None: + W = snp.ones(A.shape[0], dtype=A.dtype) + elif isinstance(W, Diagonal): + W = W.diagonal + self.A = A + self.D = D + self.W = W + self.cho_factor = cho_factor + self.lower = lower + self.check_finite = check_finite + + assert isinstance(W, Array) + N, M = A.shape + if N < M and D.ndim == 1: + G = snp.diag(1.0 / W) + A @ (A.T.conj() / D[:, snp.newaxis]) + else: + if D.ndim == 1: + G = A.T.conj() @ (W[:, snp.newaxis] * A) + snp.diag(D) + else: + G = A.T.conj() @ (W[:, snp.newaxis] * A) + D + + if cho_factor: + c, lower = jsl.cho_factor(G, lower=lower, check_finite=check_finite) + self.factor = (c, lower) + else: + lu, piv = jsl.lu_factor(G, check_finite=check_finite) + self.factor = (lu, piv) + + def solve(self, b: Array, check_finite: Optional[bool] = None) -> Array: + r"""Solve the linear system. + + Solve the linear system with right hand side :math:`\mb{b}` (`b` + is a vector) or :math:`B` (`b` is a 2d array). + + Args: + b: Vector :math:`\mathbf{b}` or matrix :math:`B`. + check_finite: Flag indicating whether the input array should + be checked for ``Inf`` and ``NaN`` values. If ``None``, + use the value selected on initialization. + + Returns: + Solution to the linear system. + """ + if check_finite is None: + check_finite = self.check_finite + if self.cho_factor: + fact_solve = lambda x: jsl.cho_solve(self.factor, x, check_finite=check_finite) + else: + fact_solve = lambda x: jsl.lu_solve(self.factor, x, trans=0, check_finite=check_finite) + + if b.ndim == 1: + D = self.D + else: + D = self.D[:, snp.newaxis] + N, M = self.A.shape + if N < M and self.D.ndim == 1: + w = fact_solve(self.A @ (b / D)) + x = (b - (self.A.T.conj() @ w)) / D + else: + x = fact_solve(b) + + return x + + def accuracy(self, x: Array, b: Array) -> float: + r"""Compute solution relative residual. + + Args: + x: Array :math:`\mathbf{x}` (solution). + b: Array :math:`\mathbf{b}` (right hand side of linear system). + + Returns: + Relative residual of solution. + """ + if b.ndim == 1: + D = self.D + else: + D = self.D[:, snp.newaxis] + assert isinstance(self.W, Array) + return rel_res(self.A.T.conj() @ (self.W[:, snp.newaxis] * self.A) @ x + D * x, b) + + +class ConvATADSolver: + r"""Solver for sum of convolutions plus diagonal linear system. + + Solve a linear system of the form + + .. math:: + + (A^H A + D) \mb{x} = \mb{b} + + where :math:`A` is a block-row operator with circulant blocks, i.e. it + can be written as + + .. 
math::
+
+       A = \left( \begin{array}{cccc} A_1 & A_2 & \ldots & A_{K}
+       \end{array} \right) \;,
+
+    where all of the :math:`A_k` are circular convolution operators,
+    and :math:`D` is a circular convolution operator. This problem is
+    most easily solved in the DFT transform domain, where the circular
+    convolutions become diagonal operators. Denoting the
+    frequency-domain versions of variables with a circumflex (e.g.
+    :math:`\hat{\mb{x}}` is the frequency-domain version of
+    :math:`\mb{x}`), the problem can be written as
+
+    .. math::
+
+       (\hat{A}^H \hat{A} + \hat{D}) \hat{\mb{x}} = \hat{\mb{b}} \;,
+
+    where
+
+    .. math::
+
+       \hat{A} = \left( \begin{array}{cccc} \hat{A}_1 & \hat{A}_2 &
+       \ldots & \hat{A}_{K} \end{array} \right) \;,
+
+    and :math:`\hat{D}` and all the :math:`\hat{A}_k` are diagonal
+    operators.
+
+    This linear equation is computationally expensive to solve because
+    the left hand side includes the term :math:`\hat{A}^H \hat{A}`,
+    which corresponds to the outer product of :math:`\hat{A}^H`
+    and :math:`\hat{A}`. A computationally efficient solution is
+    possible, however, by exploiting the Woodbury matrix identity
+    :cite:`wohlberg-2014-efficient`
+
+    .. math::
+
+       (B + U C V)^{-1} = B^{-1} - B^{-1} U (C^{-1} + V B^{-1} U)^{-1}
+       V B^{-1} \;.
+
+    Setting
+
+    .. math::
+
+       B &= \hat{D} \\
+       U &= \hat{A}^H \\
+       C &= I \\
+       V &= \hat{A}
+
+    we have
+
+    .. math::
+
+       (\hat{D} + \hat{A}^H \hat{A})^{-1} = \hat{D}^{-1} - \hat{D}^{-1}
+       \hat{A}^H (I + \hat{A} \hat{D}^{-1} \hat{A}^H)^{-1} \hat{A}
+       \hat{D}^{-1}
+
+    which can be simplified to
+
+    .. math::
+
+       (\hat{D} + \hat{A}^H \hat{A})^{-1} = \hat{D}^{-1} (I - \hat{A}^H
+       \hat{E}^{-1} \hat{A} \hat{D}^{-1})
+
+    by defining :math:`\hat{E} = I + \hat{A} \hat{D}^{-1} \hat{A}^H`.
+    The right hand side is much cheaper to compute because the only
+    matrix inversions involve :math:`\hat{D}`, which is diagonal, and
+    :math:`\hat{E}`, which is a weighted inner product of
+    :math:`\hat{A}^H` and :math:`\hat{A}`.
+    """
+
+    def __init__(self, A: ComposedLinearOperator, D: CircularConvolve):
+        r"""
+        Args:
+            A: Operator :math:`A`.
+            D: Operator :math:`D`.
+        """
+        if not isinstance(A, ComposedLinearOperator):
+            raise TypeError(
+                f"Operator A is required to be a ComposedLinearOperator; got a {type(A)}."
+            )
+        if not isinstance(A.A, Sum) or not isinstance(A.B, CircularConvolve):
+            raise TypeError(
+                "Operator A is required to be a composition of Sum and CircularConvolve "
+                f"linear operators; got a composition of {type(A.A)} and {type(A.B)}."
+            )
+
+        self.A = A
+        self.D = D
+        self.sum_axis = A.A.kwargs["axis"]
+        if not isinstance(self.sum_axis, int):
+            raise ValueError(
+                "Sum component of operator A must sum over a single axis of its input."
+            )
+        self.fft_axes = A.B.x_fft_axes
+        self.real_result = is_real_dtype(D.input_dtype)
+
+        Ahat = A.B.h_dft
+        Dhat = D.h_dft
+        self.AHEinv = Ahat.conj() / (
+            1.0 + snp.sum(Ahat * (Ahat.conj() / Dhat), axis=self.sum_axis, keepdims=True)
+        )
+
+    def solve(self, b: Array) -> Array:
+        r"""Solve the linear system.
+
+        Solve the linear system with right hand side :math:`\mb{b}`.
+
+        Args:
+            b: Array :math:`\mathbf{b}`.
+
+        Returns:
+            Solution to the linear system.
+ """ + assert isinstance(self.A.B, CircularConvolve) + + Ahat = self.A.B.h_dft + Dhat = self.D.h_dft + bhat = snp.fft.fftn(b, axes=self.fft_axes) + xhat = ( + bhat - (self.AHEinv * (snp.sum(Ahat * bhat / Dhat, axis=self.sum_axis, keepdims=True))) + ) / Dhat + x = snp.fft.ifftn(xhat, axes=self.fft_axes) + if self.real_result: + x = x.real + + return x + + def accuracy(self, x: Array, b: Array) -> float: + r"""Compute solution relative residual. + + Args: + x: Array :math:`\mathbf{x}` (solution). + b: Array :math:`\mathbf{b}` (right hand side of linear system). + + Returns: + Relative residual of solution. + """ + return rel_res(self.A.gram_op(x) + self.D(x), b) diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index 282caa7d1..ce51196b6 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -374,6 +374,29 @@ def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator): b = Dnew @ x snp.testing.assert_allclose(a, b, rtol=1e-5) + @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize("input_shape1", input_shapes) + @pytest.mark.parametrize("input_shape2", input_shapes) + def test_matmul(self, input_shape1, input_shape2, diagonal_dtype): + + diagonal1, key = randn(input_shape1, dtype=diagonal_dtype, key=self.key) + diagonal2, key = randn(input_shape2, dtype=diagonal_dtype, key=key) + x, key = randn(input_shape1, dtype=diagonal_dtype, key=key) + + D1 = linop.Diagonal(diagonal=diagonal1) + D2 = linop.Diagonal(diagonal=diagonal2) + + if input_shape1 != input_shape2: + with pytest.raises(ValueError): + D3 = D1 @ D2 + else: + D3 = D1 @ D2 + assert isinstance(D3, linop.Diagonal) + a = D3 @ x + D4 = linop.Diagonal(diagonal1 * diagonal2) + b = D4 @ x + snp.testing.assert_allclose(a, b, rtol=1e-5) + @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_binary_op_mismatch(self, operator): diagonal_dtype = np.float32 @@ -418,6 +441,40 @@ def test_scalar_left(self, operator): np.testing.assert_allclose(scaled_D @ x, operator(D @ x, scalar), rtol=5e-5) + @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) + def test_gram_op(self, diagonal_dtype): + + input_shape = (7,) + diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) + + D1 = linop.Diagonal(diagonal=diagonal) + D2 = D1.gram_op + D3 = D1.H @ D1 + assert isinstance(D3, linop.Diagonal) + snp.testing.assert_allclose(D2.diagonal, D3.diagonal, rtol=1e-6) + + @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize("ord", [None, "fro", "nuc", -np.inf, np.inf, 1, -1, 2, -2]) + def test_norm(self, diagonal_dtype, ord): + + input_shape = (5,) + diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) + + D1 = linop.Diagonal(diagonal=diagonal) + D2 = snp.diag(diagonal) + n1 = D1.norm(ord=ord) + n2 = snp.linalg.norm(D2, ord=ord) + snp.testing.assert_allclose(n1, n2, rtol=1e-6) + + def test_norm_except(self): + + input_shape = (5,) + diagonal, key = randn(input_shape, dtype=np.float32, key=self.key) + + D = linop.Diagonal(diagonal=diagonal) + with pytest.raises(ValueError): + n = D.norm(ord=3) + def test_adj_lazy(): dtype = np.float32 diff --git a/scico/test/optimize/test_admm.py b/scico/test/optimize/test_admm.py index b100e3f40..b246c7416 100644 --- a/scico/test/optimize/test_admm.py +++ b/scico/test/optimize/test_admm.py @@ -5,7 +5,7 @@ import pytest import scico.numpy as snp -from scico import functional, linop, loss, metric, random +from 
scico import functional, linop, loss, metric, operator, random
 from scico.optimize import ADMM
 from scico.optimize.admm import (
     CircularConvolveSolver,
@@ -13,6 +13,7 @@
     G0BlockCircularConvolveSolver,
     GenericSubproblemSolver,
     LinearSubproblemSolver,
+    MatrixSubproblemSolver,
 )


@@ -72,6 +73,37 @@ def callback(obj):
         with pytest.raises(ValueError):
             admm_.solve()

+    @pytest.mark.parametrize(
+        "solver", [LinearSubproblemSolver, MatrixSubproblemSolver, CircularConvolveSolver]
+    )
+    def test_admm_aux(self, solver):
+        maxiter = 2
+        ρ = 1e-1
+        A = operator.Abs(self.y.shape)
+        f = loss.SquaredL2Loss(y=self.y, A=A)
+        g = functional.DnCNN()
+        C = linop.Identity(self.y.shape)
+
+        with pytest.raises(TypeError):
+            admm_ = ADMM(
+                f=f,
+                g_list=[g],
+                C_list=[C],
+                rho_list=[ρ],
+                maxiter=maxiter,
+                subproblem_solver=solver(),
+            )
+
+        with pytest.raises(TypeError):
+            admm_ = ADMM(
+                f=g,
+                g_list=[g],
+                C_list=[C],
+                rho_list=[ρ],
+                maxiter=maxiter,
+                subproblem_solver=solver(),
+            )
+

 class TestReal:
     def setup_method(self, method):
@@ -204,7 +236,7 @@ def setup_method(self, method):
         self.grdA = lambda x: (𝛼 * Amx.T @ (W * Amx) + λ * Bmx.T @ Bmx) @ x
         self.grdb = 𝛼 * Amx.T @ (W[:, 0] * y)

-    def test_admm_quadratic(self):
+    def test_admm_quadratic_linear(self):
         maxiter = 100
         ρ = 1e0
         A = linop.MatrixOperator(self.Amx)
@@ -225,6 +257,27 @@
         x = admm_.solve()
         assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

+    def test_admm_quadratic_matrix(self):
+        maxiter = 50
+        ρ = 1e0
+        A = linop.MatrixOperator(self.Amx)
+        f = loss.SquaredL2Loss(y=self.y, A=A, W=linop.Diagonal(self.W[:, 0]), scale=self.𝛼 / 2.0)
+        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
+        C_list = [linop.MatrixOperator(self.Bmx)]
+        rho_list = [ρ]
+        admm_ = ADMM(
+            f=f,
+            g_list=g_list,
+            C_list=C_list,
+            rho_list=rho_list,
+            maxiter=maxiter,
+            itstat_options={"display": False},
+            x0=A.adj(self.y),
+            subproblem_solver=MatrixSubproblemSolver(),
+        )
+        x = admm_.solve()
+        assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5
+

 class TestComplex:
     def setup_method(self, method):
@@ -234,12 +287,12 @@ def setup_method(self, method):
         # Set up arrays for problem argmin (𝛼/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
         Amx, key = random.randn((MA, N), dtype=np.complex64, key=None)
         Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key)
-        y = np.random.randn(MA)
+        y, key = random.randn((MA,), dtype=np.complex64, key=key)
         𝛼 = 1.0 / 3.0
         λ = 1e0
         self.Amx = Amx
         self.Bmx = Bmx
-        self.y = jax.device_put(y)
+        self.y = y
         self.𝛼 = 𝛼
         self.λ = λ
         # Solution of problem is given by linear system (𝛼 A^T A + λ B^T B) x = A^T y
@@ -267,7 +320,7 @@ def test_admm_generic(self):
         x = admm_.solve()
         assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-3

-    def test_admm_quadratic(self):
+    def test_admm_quadratic_linear(self):
         maxiter = 50
         ρ = 1e0
         A = linop.MatrixOperator(self.Amx)
@@ -290,6 +343,27 @@
         x = admm_.solve()
         assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

+    def test_admm_quadratic_matrix(self):
+        maxiter = 50
+        ρ = 1e0
+        A = linop.MatrixOperator(self.Amx)
+        f = loss.SquaredL2Loss(y=self.y, A=A, scale=self.𝛼 / 2.0)
+        g_list = [(self.λ / 2) * functional.SquaredL2Norm()]
+        C_list = [linop.MatrixOperator(self.Bmx)]
+        rho_list = [ρ]
+        admm_ = ADMM(
+            f=f,
+            g_list=g_list,
+            C_list=C_list,
+            rho_list=rho_list,
+            maxiter=maxiter,
+            itstat_options={"display": False},
x0=A.adj(self.y), + subproblem_solver=MatrixSubproblemSolver(), + ) + x = admm_.solve() + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5 + class TestCircularConvolveSolve: def setup_method(self, method): @@ -369,7 +443,7 @@ def test_fblock_init(self): itstat_options={"display": False}, subproblem_solver=FBlockCircularConvolveSolver(), ) - with pytest.raises(ValueError): + with pytest.raises(TypeError): slvr = ADMM( f=loss.PoissonLoss(y=self.y), g_list=self.g_list, @@ -378,7 +452,7 @@ def test_fblock_init(self): itstat_options={"display": False}, subproblem_solver=FBlockCircularConvolveSolver(), ) - with pytest.raises(ValueError): + with pytest.raises(TypeError): slvr = ADMM( f=loss.SquaredL2Loss(y=self.y, A=self.A.A), g_list=self.g_list, @@ -398,7 +472,7 @@ def test_g0block_init(self): itstat_options={"display": False}, subproblem_solver=G0BlockCircularConvolveSolver(), ) - with pytest.raises(ValueError): + with pytest.raises(TypeError): slvr = ADMM( f=functional.ZeroFunctional(), g_list=[loss.PoissonLoss(y=self.y)], @@ -407,7 +481,7 @@ def test_g0block_init(self): itstat_options={"display": False}, subproblem_solver=G0BlockCircularConvolveSolver(), ) - with pytest.raises(ValueError): + with pytest.raises(TypeError): slvr = ADMM( f=functional.ZeroFunctional(), g_list=[loss.SquaredL2Loss(y=self.y)] + self.g_list, diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index abe3745a2..d5b179b62 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -6,7 +6,7 @@ import pytest import scico.numpy as snp -from scico import linop, random, solver +from scico import linop, metric, random, solver class TestSet: @@ -294,3 +294,78 @@ def test_golden(): f = lambda x, c: (x - c) ** 2 x = solver.golden(f, -snp.abs(c) - 1, snp.abs(c) + 1, args=(c,), xtol=1e-5) assert snp.max(snp.abs(x - c)) <= 1e-5 + + +@pytest.mark.parametrize("cho_factor", [True, False]) +@pytest.mark.parametrize("wide", [True, False]) +@pytest.mark.parametrize("weighted", [True, False]) +@pytest.mark.parametrize("alpha", [1e-1, 1e1]) +def test_solve_atai(cho_factor, wide, weighted, alpha): + A, key = random.randn((5, 8), dtype=snp.float32) + if wide: + x0, key = random.randn((8,), key=key) + else: + A = A.T + x0, key = random.randn((5,), key=key) + + if weighted: + W, key = random.randn((A.shape[0],), key=key) + W = snp.abs(W) + Wa = W[:, snp.newaxis] + else: + W = None + Wa = snp.array([1.0])[:, snp.newaxis] + + D = alpha * snp.ones((A.shape[1],)) + ATAD = A.T @ (Wa * A) + alpha * snp.identity(A.shape[1]) + b = ATAD @ x0 + slv = solver.ATADSolver(A, D, W=W, cho_factor=cho_factor) + x1 = slv.solve(b) + assert metric.rel_res(x0, x1) < 5e-5 + + +@pytest.mark.parametrize("cho_factor", [True, False]) +@pytest.mark.parametrize("wide", [True, False]) +@pytest.mark.parametrize("alpha", [1e-1, 1e1]) +def test_solve_aati(cho_factor, wide, alpha): + A, key = random.randn((5, 8), dtype=snp.float32) + if wide: + x0, key = random.randn((5,), key=key) + else: + A = A.T + x0, key = random.randn((8,), key=key) + + D = alpha * snp.ones((A.shape[0],)) + AATD = A @ A.T + alpha * snp.identity(A.shape[0]) + b = AATD @ x0 + slv = solver.ATADSolver(A.T, D) + x1 = slv.solve(b) + assert metric.rel_res(x0, x1) < 5e-5 + + +@pytest.mark.parametrize("cho_factor", [True, False]) +@pytest.mark.parametrize("wide", [True, False]) +@pytest.mark.parametrize("vector", [True, False]) +def test_solve_atad(cho_factor, wide, vector): + A, key = random.randn((5, 8), dtype=snp.float32) + if wide: + D, key = 
random.randn((8,), key=key)
+        if vector:
+            x0, key = random.randn((8,), key=key)
+        else:
+            x0, key = random.randn((8, 3), key=key)
+    else:
+        A = A.T
+        D, key = random.randn((5,), key=key)
+        if vector:
+            x0, key = random.randn((5,), key=key)
+        else:
+            x0, key = random.randn((5, 3), key=key)
+
+    D = snp.abs(D)  # only required for Cholesky, but improves accuracy for LU
+    ATAD = A.T @ A + snp.diag(D)
+    b = ATAD @ x0
+    slv = solver.ATADSolver(A, D, cho_factor=cho_factor)
+    x1 = slv.solve(b)
+    assert metric.rel_res(x0, x1) < 5e-5
+    assert slv.accuracy(x1, b) < 5e-5
diff --git a/scico/test/test_util.py b/scico/test/test_util.py
index 5b58c0bf5..5bbdf002d 100644
--- a/scico/test/test_util.py
+++ b/scico/test/test_util.py
@@ -8,7 +8,32 @@
 import pytest

 import scico.numpy as snp
-from scico.util import ContextTimer, Timer, check_for_tracer, partial, url_get
+from scico.util import (
+    ContextTimer,
+    Timer,
+    check_for_tracer,
+    partial,
+    rgetattr,
+    rsetattr,
+    url_get,
+)
+
+
+def test_rattr():
+    class A:
+        class B:
+            c = 0
+
+        b = B()
+
+    a = A()
+    rsetattr(a, "b.c", 1)
+    assert rgetattr(a, "b.c") == 1
+
+    assert rgetattr(a, "c.d", 10) == 10
+
+    with pytest.raises(AttributeError):
+        assert rgetattr(a, "c.d")


 def test_partial_pos():
diff --git a/scico/util.py b/scico/util.py
index 73982fd0b..d57c8efc6 100644
--- a/scico/util.py
+++ b/scico/util.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright (C) 2020-2022 by SCICO Developers
+# Copyright (C) 2020-2023 by SCICO Developers
 # All rights reserved. BSD 3-clause License.
 # This file is part of the SCICO package. Details of the copyright and
 # user license can be found in the 'LICENSE' file distributed with the
@@ -14,7 +14,7 @@
 import socket
 import urllib.error as urlerror
 import urllib.request as urlrequest
-from functools import wraps
+from functools import reduce, wraps
 from timeit import default_timer as timer
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union

@@ -23,6 +23,44 @@

 from jax.interpreters.partial_eval import DynamicJaxprTracer


+def rgetattr(obj: object, name: str, default: Optional[Any] = None) -> Any:
+    """Recursive version of :func:`getattr`.
+
+    Args:
+        obj: Object with the attribute to be accessed.
+        name: Path to the attribute, with components delimited by a "."
+            character.
+        default: Default value to be returned if the attribute does not
+            exist.
+
+    Returns:
+        Attribute value, or default if the attribute does not exist.
+    """
+
+    try:
+        return reduce(getattr, name.split("."), obj)
+    except AttributeError as e:
+        if default is not None:
+            return default
+        else:
+            raise e
+
+
+def rsetattr(obj: object, name: str, value: Any):
+    """Recursive version of :func:`setattr`.
+
+    Args:
+        obj: Object with the attribute to be set.
+        name: Path to the attribute, with components delimited by a "."
+            character.
+        value: Value to which the attribute is to be set.
+    """
+
+    # See goo.gl/BVJ7MN
+    path = name.split(".")
+    setattr(reduce(getattr, path[:-1], obj), path[-1], value)
+
+
 def partial(func: Callable, indices: Sequence, *fixargs: Any, **fixkwargs: Any) -> Callable:
     """Flexible partial function creation.