Change typing to allow f to be any Functional rather than requiring it to be a Loss (#25)
bwohlberg authored Oct 8, 2021
1 parent c310e7f commit 8803dc6
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions scico/admm.py
@@ -22,7 +22,7 @@
from scico.diagnostics import IterationStats
from scico.functional import Functional
from scico.linop import CircularConvolve, Identity, LinearOperator
- from scico.loss import Loss, SquaredL2Loss, WeightedSquaredL2Loss
+ from scico.loss import SquaredL2Loss, WeightedSquaredL2Loss
from scico.math import is_real_dtype
from scico.numpy.linalg import norm
from scico.solver import cg as scico_cg
@@ -125,8 +125,8 @@ class LinearSubproblemSolver(SubproblemSolver):
\sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;,
where :math:`W` is the weighting :class:`.LinearOperator` from the
- :class:`.WeightedSquaredL2Loss` instance. This update step reduces to the solution
- of the linear system
+ :class:`.WeightedSquaredL2Loss` instance. This update step reduces to the
+ solution of the linear system
.. math::
\left(A^* W A + \sum_{i=1}^N \rho_i C_i^* C_i \right) \mb{x}^{(k+1)} = \;
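Context, not part of the diff: the :math:`\mb{x}`-update objective above is quadratic, so setting its gradient to zero yields this linear system, whose right-hand side (collapsed in this view) works out to :math:`A^* W \mb{y} + \sum_{i=1}^N \rho_i C_i^* (\mb{z}^{(k)}_i - \mb{u}^{(k)}_i)`. A minimal sketch of solving such a system with conjugate gradient, using random NumPy stand-ins rather than scico operators:

import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

rng = np.random.default_rng(0)
n = 16
A = rng.standard_normal((n, n))              # stand-in forward operator
W = np.diag(rng.uniform(0.5, 1.5, size=n))   # diagonal weighting
C_list = [np.eye(n)]                         # single C_i = I for simplicity
rho_list = [1.0]
y = rng.standard_normal(n)
z_list = [rng.standard_normal(n)]
u_list = [np.zeros(n)]

def lhs(x):
    # Apply A^T W A + sum_i rho_i C_i^T C_i to x.
    out = A.T @ (W @ (A @ x))
    for rho, C in zip(rho_list, C_list):
        out = out + rho * (C.T @ (C @ x))
    return out

# Right-hand side: A^T W y + sum_i rho_i C_i^T (z_i - u_i).
b = A.T @ (W @ y)
for rho, C, z, u in zip(rho_list, C_list, z_list, u_list):
    b = b + rho * (C.T @ (z - u))

x_new, info = cg(LinearOperator((n, n), matvec=lhs), b)
assert info == 0  # CG reports convergence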
@@ -309,11 +309,11 @@ class ADMM:
.. math::
\argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;,
- where :math:`f` is an instance of :class:`.Loss`, the :math:`g_i` are :class:`.Functional`,
+ where :math:`f` and the :math:`g_i` are instances of :class:`.Functional`,
and the :math:`C_i` are :class:`.LinearOperator`.
- The optimization problem is solved by introducing the splitting :math:`\mb{z}_i =
- C_i \mb{x}` and solving
+ The optimization problem is solved by introducing the splitting
+ :math:`\mb{z}_i = C_i \mb{x}` and solving
.. math::
\argmin_{\mb{x}, \mb{z}_i} \; f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) \;
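For reference (not shown in this excerpt), the standard scaled-form ADMM updates used to solve this split problem, consistent with the :math:`\mb{x}`-update in the LinearSubproblemSolver docstring above, are

.. math::
   \mb{x}^{(k+1)} &= \argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N \frac{\rho_i}{2}
   \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \\
   \mb{z}^{(k+1)}_i &= \argmin_{\mb{z}_i} \; g_i(\mb{z}_i) + \frac{\rho_i}{2}
   \norm{\mb{z}_i - \mb{u}^{(k)}_i - C_i \mb{x}^{(k+1)}}_2^2 \\
   \mb{u}^{(k+1)}_i &= \mb{u}^{(k)}_i + C_i \mb{x}^{(k+1)} - \mb{z}^{(k+1)}_i \;.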
@@ -338,7 +338,7 @@ class ADMM:
Attributes:
- f (:class:`.Loss`): Loss function
+ f (:class:`.Functional`): Functional :math:`f` (usually a :class:`.Loss`)
g_list (list of :class:`.Functional`): List of :math:`g_i`
functionals. Must be same length as :code:`C_list` and :code:`rho_list`.
C_list (list of :class:`.LinearOperator`): List of :math:`C_i` operators.
@@ -349,8 +349,9 @@
Must be same length as :code:`C_list` and :code:`g_list`.
u_list (list of array-like): List of scaled Lagrange multipliers
:math:`\mb{u}_i` at current iteration.
- x (array-like): Solution.
- subproblem_solver (:class:`.SubproblemSolver`): Solver for :math:`\mb{x}`-update step.
+ x (array-like): Solution
+ subproblem_solver (:class:`.SubproblemSolver`): Solver for
+ :math:`\mb{x}`-update step.
z_list (list of array-like): List of auxiliary variables :math:`\mb{z}_i`
at current iteration.
z_list_old (list of array-like): List of auxiliary variables :math:`\mb{z}_i`
@@ -359,7 +360,7 @@

def __init__(
self,
- f: Loss,
+ f: Functional,
g_list: List[Functional],
C_list: List[LinearOperator],
rho_list: List[float],
@@ -372,31 +373,36 @@ def __init__(
r"""Initialize an :class:`ADMM` object.
Args:
- f : Loss function
- g_list : List of :math:`g_i`
- functionals. Must be same length as :code:`C_list` and :code:`rho_list`
+ f : Functional :math:`f` (usually a loss function)
+ g_list : List of :math:`g_i` functionals. Must be same length
+ as :code:`C_list` and :code:`rho_list`
C_list : List of :math:`C_i` operators
rho_list : List of :math:`\rho_i` penalty parameters.
Must be same length as :code:`C_list` and :code:`g_list`
- x0 : Starting point for :math:`\mb{x}`. If None, defaults to an array of zeros.
+ x0 : Starting point for :math:`\mb{x}`. If None, defaults to
+ an array of zeros.
maxiter : Number of ADMM outer-loop iterations. Default: 100.
- subproblem_solver : Solver for :math:`\mb{x}`-update step. Defaults to ``None``, which
- implies use of an instance of :class:`GenericSubproblemSolver`.
- verbose: Flag indicating whether iteration statistics should be displayed.
- itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` is a dict suitable
- for passing to the `fields` argument of the :class:`.diagnostics.IterationStats`
- initializer, and `insertfunc` is a function with two parameters, an integer
- and an ADMM object, responsible for constructing a tuple ready for insertion into
- the :class:`.diagnostics.IterationStats` object. If None, default values are
- used for the tuple components.
+ subproblem_solver : Solver for :math:`\mb{x}`-update step.
+ Defaults to ``None``, which implies use of an instance of
+ :class:`GenericSubproblemSolver`.
+ verbose: Flag indicating whether iteration statistics should
+ be displayed.
+ itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec`
+ is a dict suitable for passing to the `fields` argument
+ of the :class:`.diagnostics.IterationStats` initializer,
+ and `insertfunc` is a function with two parameters, an
+ integer and an ADMM object, responsible for constructing
+ a tuple ready for insertion into the
+ :class:`.diagnostics.IterationStats` object. If None,
+ default values are used for the tuple components.
"""
N = len(g_list)
if len(C_list) != N:
raise Exception(f"len(C_list)={len(C_list)} not equal to len(g_list)={N}")
if len(rho_list) != N:
raise Exception(f"len(rho_list)={len(rho_list)} not equal to len(g_list)={N}")

- self.f: Loss = f
+ self.f: Functional = f
self.g_list: List[Functional] = g_list
self.C_list: List[LinearOperator] = C_list
self.rho_list: List[float] = rho_list
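A hypothetical usage sketch of what this change enables: passing a Functional that is not a Loss as f, with the data-fidelity term among the g_i instead. Names not shown in this diff (functional.L1Norm, linop.Identity, loss.SquaredL2Loss, ADMM.solve) are assumptions about the scico API at this revision and may differ.

import numpy as np
from scico import functional, linop, loss
from scico.admm import ADMM

y = np.random.randn(16).astype(np.float32)   # stand-in measurements
A = linop.Identity(y.shape)                  # stand-in forward operator

f = functional.L1Norm()                      # a Functional, not a Loss
g_list = [loss.SquaredL2Loss(y=y, A=A)]      # data fidelity moved into g_list
C_list = [linop.Identity(y.shape)]
rho_list = [1.0]

solver = ADMM(
    f=f,
    g_list=g_list,
    C_list=C_list,
    rho_list=rho_list,
    x0=np.zeros_like(y),
    maxiter=20,
    # subproblem_solver left as None, i.e. GenericSubproblemSolver, since
    # LinearSubproblemSolver still assumes a (weighted) squared-L2 f.
)
x_hat = solver.solve()

Before this commit the annotations and docstrings required f to be a Loss even though configurations like the one sketched above are otherwise well defined.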
