diff --git a/scico/admm.py b/scico/admm.py index 8a8cb234d..939e026d8 100644 --- a/scico/admm.py +++ b/scico/admm.py @@ -22,7 +22,7 @@ from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import CircularConvolve, Identity, LinearOperator -from scico.loss import Loss, SquaredL2Loss, WeightedSquaredL2Loss +from scico.loss import SquaredL2Loss, WeightedSquaredL2Loss from scico.math import is_real_dtype from scico.numpy.linalg import norm from scico.solver import cg as scico_cg @@ -125,8 +125,8 @@ class LinearSubproblemSolver(SubproblemSolver): \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;, where :math:`W` is the weighting :class:`.LinearOperator` from the - :class:`.WeightedSquaredL2Loss` instance. This update step reduces to the solution - of the linear system + :class:`.WeightedSquaredL2Loss` instance. This update step reduces to the + solution of the linear system .. math:: \left(A^* W A + \sum_{i=1}^N \rho_i C_i^* C_i \right) \mb{x}^{(k+1)} = \; @@ -309,11 +309,11 @@ class ADMM: .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;, - where :math:`f` is an instance of :class:`.Loss`, the :math:`g_i` are :class:`.Functional`, + where :math:`f` and the :math:`g_i` are instances of :class:`.Functional`, and the :math:`C_i` are :class:`.LinearOperator`. - The optimization problem is solved by introducing the splitting :math:`\mb{z}_i = - C_i \mb{x}` and solving + The optimization problem is solved by introducing the splitting + :math:`\mb{z}_i = C_i \mb{x}` and solving .. math:: \argmin_{\mb{x}, \mb{z}_i} \; f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) \; @@ -338,7 +338,7 @@ class ADMM: Attributes: - f (:class:`.Loss`): Loss function + f (:class:`.Functional`): Functional :math:`f` (usually a :class:`.Loss`) g_list (list of :class:`.Functional`): List of :math:`g_i` functionals. Must be same length as :code:`C_list` and :code:`rho_list`. C_list (list of :class:`.LinearOperator`): List of :math:`C_i` operators. @@ -349,8 +349,9 @@ class ADMM: Must be same length as :code:`C_list` and :code:`g_list`. u_list (list of array-like): List of scaled Lagrange multipliers :math:`\mb{u}_i` at current iteration. - x (array-like): Solution. - subproblem_solver (:class:`.SubproblemSolver`): Solver for :math:`\mb{x}`-update step. + x (array-like): Solution + subproblem_solver (:class:`.SubproblemSolver`): Solver for + :math:`\mb{x}`-update step. z_list (list of array-like): List of auxiliary variables :math:`\mb{z}_i` at current iteration. z_list_old (list of array-like): List of auxiliary variables :math:`\mb{z}_i` @@ -359,7 +360,7 @@ class ADMM: def __init__( self, - f: Loss, + f: Functional, g_list: List[Functional], C_list: List[LinearOperator], rho_list: List[float], @@ -372,23 +373,28 @@ def __init__( r"""Initialize an :class:`ADMM` object. Args: - f : Loss function - g_list : List of :math:`g_i` - functionals. Must be same length as :code:`C_list` and :code:`rho_list` + f : Functional :math:`f` (usually a loss function) + g_list : List of :math:`g_i` functionals. Must be same length + as :code:`C_list` and :code:`rho_list` C_list : List of :math:`C_i` operators rho_list : List of :math:`\rho_i` penalty parameters. Must be same length as :code:`C_list` and :code:`g_list` - x0 : Starting point for :math:`\mb{x}`. If None, defaults to an array of zeros. + x0 : Starting point for :math:`\mb{x}`. If None, defaults to + an array of zeros. maxiter : Number of ADMM outer-loop iterations. Default: 100. - subproblem_solver : Solver for :math:`\mb{x}`-update step. Defaults to ``None``, which - implies use of an instance of :class:`GenericSubproblemSolver`. - verbose: Flag indicating whether iteration statistics should be displayed. - itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` is a dict suitable - for passing to the `fields` argument of the :class:`.diagnostics.IterationStats` - initializer, and `insertfunc` is a function with two parameters, an integer - and an ADMM object, responsible for constructing a tuple ready for insertion into - the :class:`.diagnostics.IterationStats` object. If None, default values are - used for the tuple components. + subproblem_solver : Solver for :math:`\mb{x}`-update step. + Defaults to ``None``, which implies use of an instance of + :class:`GenericSubproblemSolver`. + verbose: Flag indicating whether iteration statistics should + be displayed. + itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` + is a dict suitable for passing to the `fields` argument + of the :class:`.diagnostics.IterationStats` initializer, + and `insertfunc` is a function with two parameters, an + integer and an ADMM object, responsible for constructing + a tuple ready for insertion into the + :class:`.diagnostics.IterationStats` object. If None, + default values are used for the tuple components. """ N = len(g_list) if len(C_list) != N: @@ -396,7 +402,7 @@ def __init__( if len(rho_list) != N: raise Exception(f"len(rho_list)={len(rho_list)} not equal to len(g_list)={N}") - self.f: Loss = f + self.f: Functional = f self.g_list: List[Functional] = g_list self.C_list: List[LinearOperator] = C_list self.rho_list: List[float] = rho_list