Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docs fixes #21

Merged
merged 8 commits into from
Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions scico/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ class LinearSubproblemSolver(SubproblemSolver):
\mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2} \norm{\mb{y} - A x}_W^2 +
\sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;,

where :math:`W` is the weighting :class:`.LinearOperator` from the :class:`.WeightedSquaredL2Loss`
instance. This update step reduces to the solution of the linear system
where :math:`W` is the weighting :class:`.LinearOperator` from the
:class:`.WeightedSquaredL2Loss` instance. This update step reduces to the solution
of the linear system

.. math::
\left(A^* W A + \sum_{i=1}^N \rho_i C_i^* C_i \right) \mb{x}^{(k+1)} = \;
Expand Down Expand Up @@ -173,8 +174,8 @@ def internal_init(self, admm):

super().internal_init(admm)

# set lhs_op = \sum_i rho_i * Ci.H @ CircularConvolve
# use reduce as the initialization of this sum is messy otherwise
# Set lhs_op = \sum_i rho_i * Ci.H @ CircularConvolve
# Use reduce as the initialization of this sum is messy otherwise
lhs_op = reduce(
lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)]
)
Expand Down Expand Up @@ -295,21 +296,21 @@ def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:


class ADMM:
r"""Basic Alternating Direction Method of Multipliers (ADMM)
algorithm :cite:`boyd-2010-distributed`.
r"""Basic Alternating Direction Method of Multipliers (ADMM) algorithm
:cite:`boyd-2010-distributed`.

|

Solve an optimization problem of the form

.. math::
\argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}),
\argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;,

where :math:`f` is an instance of :class:`.Loss`, the :math:`g_i` are :class:`.Functional`,
and the :math:`C_i` are :class:`.LinearOperator`.

The optimization problem is solved by introducing the splitting :math:`\mb{z}_i = C_i \mb{x}`
and solving
The optimization problem is solved by introducing the splitting :math:`\mb{z}_i =
C_i \mb{x}` and solving

.. math::
\argmin_{\mb{x}, \mb{z}_i} \; f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) \;
Expand Down Expand Up @@ -423,7 +424,7 @@ def itstat_func(i, obj):
)

else:
# at least one 'g' can't be evaluated, so drop objective from the default itstat
# At least one 'g' can't be evaluated, so drop objective from the default itstat
itstat_dict = {"Iter": "%d", "Primal Rsdl": "%8.3e", "Dual Rsdl": "%8.3e"}

def itstat_func(i, admm):
Expand Down Expand Up @@ -502,7 +503,7 @@ def norm_dual_residual(self) -> float:
r"""Compute the :math:`\ell_2` norm of the dual residual.

.. math::
\left(\sum_{i=1}^N \norm{\mb{z}^{(k)} - \mb{z}^{(k-1)}}_2^2\right)^{1/2}
\left(\sum_{i=1}^N \norm{\mb{z}^{(k)}_i - \mb{z}^{(k-1)}_i}_2^2\right)^{1/2}

Returns:
Current value of dual residual
Expand All @@ -514,12 +515,12 @@ def norm_dual_residual(self) -> float:
return snp.sqrt(out)

def z_init(self, x0: Union[JaxArray, BlockArray]):
r"""Initialize auxiliary variables :math:`\mb{z}`.
r"""Initialize auxiliary variables :math:`\mb{z}_i`.

Initializes to

.. math::
\mb{z}_i = C_i \mb{x}_0
\mb{z}_i = C_i \mb{x}^{(0)}

:code:`z_list` and :code:`z_list_old` are initialized to the same value.

Expand All @@ -531,12 +532,12 @@ def z_init(self, x0: Union[JaxArray, BlockArray]):
return z_list, z_list_old

def u_init(self, x0: Union[JaxArray, BlockArray]):
r"""Initialize scaled Lagrange multipliers :math:`\mb{u}`.
r"""Initialize scaled Lagrange multipliers :math:`\mb{u}_i`.

Initializes to

.. math::
\mb{u}_i = C_i \mb{x}_0
\mb{u}_i = C_i \mb{x}^{(0)}


Args:
Expand All @@ -556,8 +557,8 @@ def x_step(self, x):
return self.subproblem_solver.solve(x)

def z_and_u_step(self, u_list, z_list):
r""" Update the auxiliary variables :math:`\mb{z}` and scaled Lagrange multipliers
:math:`\mb{u}`.
r"""Update the auxiliary variables :math:`\mb{z}_i` and scaled Lagrange multipliers
:math:`\mb{u}_i`.

The auxiliary variables are updated according to

Expand All @@ -576,15 +577,15 @@ def z_and_u_step(self, u_list, z_list):
"""
z_list_old = z_list.copy()

# unpack the arrays that will be changing to prevent side-effects
# Unpack the arrays that will be changing to prevent side-effects
z_list = self.z_list
u_list = self.u_list

for i, (rhoi, fi, Ci, zi, ui) in enumerate(
for i, (rhoi, gi, Ci, zi, ui) in enumerate(
zip(self.rho_list, self.g_list, self.C_list, z_list, u_list)
):
Cix = Ci(self.x)
zi = fi.prox(Cix + ui, 1 / rhoi)
zi = gi.prox(Cix + ui, 1 / rhoi)
ui = ui + Cix - zi
z_list[i] = zi
u_list[i] = ui
Expand All @@ -594,7 +595,6 @@ def step(self):
"""Perform a single ADMM iteration.

Equivalent to calling :meth:`.x_step` followed by :meth:`.z_and_u_step`.

"""
self.x = self.x_step(self.x)
self.u_list, self.z_list, self.z_list_old = self.z_and_u_step(self.u_list, self.z_list)
Expand Down
2 changes: 1 addition & 1 deletion scico/objax.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,5 +159,5 @@ def __call__(self, x: JaxArray, training: bool) -> JaxArray:
x = ly(x, training)

x = self.post_conv(x)
# residual-like output
# Residual-like output
return base - x
6 changes: 3 additions & 3 deletions scico/pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def update(self, v: Union[JaxArray, BlockArray]) -> float:

if self.xprev is None:
# Solution and gradient of previous iterate are required.
# For first iteration these variables are stored and current estimation is returned.
# For first iteration these variables are stored and current estimate is returned.
self.xprev = v
self.gradprev = self.pgm.f.grad(self.xprev)
L = self.pgm.L
Expand Down Expand Up @@ -197,7 +197,7 @@ def update(self, v: Union[JaxArray, BlockArray]) -> float:
# Store current state and gradient for next update.
self.xprev = v
self.gradprev = gradv
# Store current estimations of Barzilai-Borwein 1 (Lbb1) and Barzilai-Borwein 2 (Lbb2).
# Store current estimates of Barzilai-Borwein 1 (Lbb1) and Barzilai-Borwein 2 (Lbb2).
self.Lbb1prev = Lbb1
self.Lbb2prev = Lbb2

Expand Down Expand Up @@ -333,7 +333,7 @@ class PGM:

The function :math:`f` must be smooth and :math:`g` must have a defined prox.

Uses helper :class:`StepSize` to provide an estimation of the Lipschitz constant :math:`L` of :math:`f`. The step size :math:`\alpha` is the reciprocal of :math:`L`, i.e.: :math:`\alpha = 1 / L`.
Uses helper :class:`StepSize` to provide an estimate of the Lipschitz constant :math:`L` of :math:`f`. The step size :math:`\alpha` is the reciprocal of :math:`L`, i.e.: :math:`\alpha = 1 / L`.
"""

def __init__(
Expand Down