Skip to content

Commit

Permalink
Add ADMM subproblem solver for matrix operators (#426)
Browse files Browse the repository at this point in the history
* Add utility functions

* Add missing typing

* Add linear system solver class

* Move common code in block circulant subproblem solvers to a separate class

* Clean up

* Clean up

* Docs corrections

* Minor import changes

* Clean up

* Change exception type

* Trivial change

* Add a note to docs

* Fix exception types

* Minor docs fixes

* Extend function to support diagonal term

* Improve docs

* Fix support for 2d X and B

* Add support for weight matrix

* Resolve typing errors

* Add ADMM subproblem solver for matrix operators

* Add support for complex problems

* Minor docs improvement

* Add some tests

* Minor coding style correction

* Add specializations of __matmul__, T, H, and conj method for Diagonal linops

* Switch example script to use MatrixSubproblemSolver

* Docs fix

* Add norm method for Diagonal linop

* Typing fix

* Ensure common dtype

* Fix test function name and extend tests

* Allow small coverage drop

* Extend a test

* Minor docs fix

* Add some tests

* Minor writing style fix

* Rename classes per PR review comment
  • Loading branch information
bwohlberg authored Jun 16, 2023
1 parent 894f8b3 commit 5c7f6ba
Show file tree
Hide file tree
Showing 13 changed files with 881 additions and 143 deletions.
8 changes: 8 additions & 0 deletions .github/codecov.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
coverage:
precision: 2
round: nearest
range: "80...100"

status:
project:
default:
target: auto
threshold: 0.05%
patch: false
4 changes: 2 additions & 2 deletions examples/scripts/sparsecode_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import jax

from scico import functional, linop, loss, plot
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.optimize.admm import ADMM, MatrixSubproblemSolver
from scico.util import device_info

"""
Expand Down Expand Up @@ -67,7 +67,7 @@
rho_list=rho_list,
x0=A.adj(y),
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(),
subproblem_solver=MatrixSubproblemSolver(),
itstat_options={"display": True, "period": 10},
)

Expand Down
64 changes: 61 additions & 3 deletions scico/linop/_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def __init__(
r"""
Args:
diagonal: Diagonal elements of this :class:`LinearOperator`.
input_shape: Shape of input array. By default, equal to
input_shape: Shape of input array. By default, equal to
`diagonal.shape`, but may also be set to a shape that is
broadcast-compatiable with `diagonal.shape`.
broadcast-compatible with `diagonal.shape`.
input_dtype: `dtype` of input argument. The default,
``None``, means `diagonal.dtype`.
"""
Expand All @@ -64,7 +64,7 @@ def __init__(
elif isinstance(diagonal, BlockArray):
raise ValueError("Parameter diagonal was a BlockArray but input_shape was not nested.")
else:
raise ValueError("Parameter diagonal was a not BlockArray but input_shape was nested.")
raise ValueError("Parameter diagonal was not a BlockArray but input_shape was nested.")

super().__init__(
input_shape=input_shape,
Expand All @@ -77,6 +77,29 @@ def __init__(
def _eval(self, x):
return x * self.diagonal

@property
def T(self) -> Diagonal:
"""Transpose of this :class:`Diagonal`."""
return self

def conj(self) -> Diagonal:
"""Complex conjugate of this :class:`Diagonal`."""
return Diagonal(diagonal=self.diagonal.conj())

@property
def H(self) -> Diagonal:
"""Hermitian transpose of this :class:`Diagonal`."""
return self.conj()

@property
def gram_op(self) -> Diagonal:
"""Gram operator of this :class:`Diagonal`.
Return a new :class:`Diagonal` :code:`G` such that
:code:`G(x) = A.adj(A(x)))`.
"""
return Diagonal(diagonal=self.diagonal.conj() * self.diagonal)

@partial(_wrap_add_sub, op=operator.add)
def __add__(self, other):
if self.diagonal.shape == other.diagonal.shape:
Expand All @@ -101,6 +124,41 @@ def __rmul__(self, scalar):
def __truediv__(self, scalar):
return Diagonal(diagonal=self.diagonal / scalar)

def __matmul__(self, other):
# self @ other
if isinstance(other, Diagonal):
if self.shape == other.shape:
return Diagonal(diagonal=self.diagonal * other.diagonal)

raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")

else:
return self(other)

def norm(self, ord=None): # pylint: disable=W0622
"""Compute the matrix norm of the diagonal operator.
Valid values of `ord` and the corresponding norm definition
are those listed under "norm for matrices" in the
:func:`scico.numpy.linalg.norm` documentation.
"""
ordfunc = {
"fro": lambda x: snp.linalg.norm(x),
"nuc": lambda x: snp.sum(snp.abs(x)),
-snp.inf: lambda x: snp.abs(x).min(),
snp.inf: lambda x: snp.abs(x).max(),
}
mord = ord
if mord is None:
mord = "fro"
elif mord in (-1, -2):
mord = -snp.inf
elif mord in (1, 2):
mord = snp.inf
if mord not in ordfunc:
raise ValueError(f"Invalid value {ord} for parameter ord.")
return ordfunc[mord](self.diagonal)


class Identity(Diagonal):
"""Identity operator."""
Expand Down
4 changes: 2 additions & 2 deletions scico/linop/_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __mul__(self, other):
raise TypeError(f"Operation __mul__ not defined between {type(self)} and {type(other)}.")

def __rmul__(self, other):
# Multiplication is commutative
# multiplication is commutative
return self * other

def __truediv__(self, other):
Expand Down Expand Up @@ -255,6 +255,6 @@ def gram_op(self):
def norm(self, ord=None, axis=None, keepdims=False): # pylint: disable=W0622
"""Compute the norm of the dense matrix `self.A`.
Call :func:`scico.numpy.norm` on the dense matrix `self.A`.
Call :func:`scico.numpy.linalg.norm` on the dense matrix `self.A`.
"""
return snp.linalg.norm(self.A, ord=ord, axis=axis, keepdims=keepdims)
3 changes: 2 additions & 1 deletion scico/optimize/_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
G0BlockCircularConvolveSolver,
GenericSubproblemSolver,
LinearSubproblemSolver,
MatrixSubproblemSolver,
SubproblemSolver,
)
from ._common import Optimizer
Expand Down Expand Up @@ -189,7 +190,7 @@ def _itstat_extra_fields(self):
)
elif (
type(self.subproblem_solver)
in [FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver]
in [MatrixSubproblemSolver, FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver]
and self.subproblem_solver.check_solve
):
itstat_fields.update({"Slv Res": "%9.3e"})
Expand Down
Loading

0 comments on commit 5c7f6ba

Please sign in to comment.