Added a new solver for Proximal Averaged Projected Gradient Method #466

Closed
wants to merge 34 commits

34 commits
1b7f583
Added a new Functional for TV Norm implementing its proximal operator…
Oct 5, 2023
9d1d73a
added checks for input shape in TV2DNorm
Oct 5, 2023
877df4c
fixed lint errors, changed required argument to default in TV2DNorm, …
Oct 5, 2023
4947382
some unsaved changes from last commit
Oct 5, 2023
4375ddf
some unsaved changes from last commit
Oct 5, 2023
71fc636
newline at end of file error
Oct 5, 2023
c62da28
sort imports lint error
Oct 5, 2023
98ce989
removed the default shape parameter from TV2DNorm
Oct 6, 2023
096d1c9
Some docs edits
bwohlberg Oct 9, 2023
c2e1de5
Disable BlockArray tests on TV2DNorm
bwohlberg Oct 9, 2023
3a0cdb0
Fix black formatting
bwohlberg Oct 9, 2023
605d11b
updated the TV norm logic to apply shrinkage to only the difference o…
Oct 11, 2023
a7e82ba
Merge branch 'main' into tv-norm
bwohlberg Oct 31, 2023
c8efe90
Implementation supporting arbitrary dimensional inputs
bwohlberg Nov 2, 2023
b5e8fc9
Merge branch 'main' into tv-norm
bwohlberg Nov 3, 2023
2654882
Merge branch 'tv-norm' into tv-norm-alt-ver
bwohlberg Nov 3, 2023
ec8686e
Add a test
bwohlberg Nov 3, 2023
4f2f189
Minor changes
bwohlberg Nov 3, 2023
b7427f7
New implementation of TV norm and approximage prox
bwohlberg Nov 3, 2023
c0c9633
Clean up
bwohlberg Nov 3, 2023
f251c60
Typo fix
bwohlberg Nov 3, 2023
feb4b77
Minor change
bwohlberg Nov 4, 2023
7fe98b9
Add change log entry
bwohlberg Nov 4, 2023
2963523
Merge pull request #2 from shnaqvi/tv-norm-alt-ver
shnaqvi Nov 5, 2023
ded11f8
Resolve typing errors
bwohlberg Nov 5, 2023
c760e45
Resolve some oversights and issues arising when 64 bit floats enabled
bwohlberg Nov 5, 2023
dafd626
Standardise code formatting
bwohlberg Nov 5, 2023
2949931
Standardise code formatting
bwohlberg Nov 5, 2023
6a654ec
Standardise code formatting
bwohlberg Nov 5, 2023
3b7f75b
Apply skipped pre-commit
bwohlberg Nov 5, 2023
64228e8
added a new solver for solving composite prior minimization problem, …
Nov 8, 2023
de76dec
Fixed typos that were causing the lint and mypy tests to fail.
Nov 8, 2023
7d7df69
stylistic sugar changes for lint tests
Nov 8, 2023
d694f01
Merge branch 'main' into pa_pgm
shnaqvi Nov 8, 2023
127 changes: 127 additions & 0 deletions scico/optimize/_papgm.py
@@ -0,0 +1,127 @@
"""Proximal Averaged Accelerated Projected Gradient Method."""

from typing import List, Optional, Tuple, Union

import scico.numpy as snp
from scico.functional import Functional
from scico.loss import Loss
from scico.numpy import Array, BlockArray

from ._common import Optimizer


class AcceleratedPAPGM(Optimizer):
    r"""Accelerated Proximal Averaged Projected Gradient Method (AcceleratedPAPGM).

Minimize a function of the form :math:`f(\mb{x}) + \sum_{i=1}^N \rho_i g_i(\mb{x})`,

    where :math:`f` and the :math:`g_i` are instances of :class:`.Functional`,
    and the :math:`\rho_i` are non-negative scalars that sum to 1. The method
    modifies FISTA to handle minimization with a sum of proximable prior
    terms (composite priors) :cite:`yaoliang-2013-nips`.

"""

def __init__(
self,
f: Union[Loss, Functional],
g_list: List[Functional],
rho_list: List[float],
L0: float,
x0: Union[Array, BlockArray],
**kwargs,
):
r"""
Args:
f: (:class:`.Functional`): Functional :math:`f` (usually a
:class:`.Loss`)
g_list: (list of :class:`.Functional`): List of :math:`g_i`
functionals. Must be same length as :code:`rho_list`.
rho_list: (list of scalars): List of :math:`\rho_i` penalty
parameters. Must be same length as :code:`g_list` and sum to 1.
            L0: (float): Initial estimate of the Lipschitz constant of
                the gradient of :math:`f`.
x0: (array-like): Starting point for :math:`\mb{x}`.
**kwargs: Additional optional parameters handled by
initializer of base class :class:`.Optimizer`.
"""
self.f: Union[Loss, Functional] = f
self.g_list: List[Functional] = g_list
self.rho_list: List[float] = rho_list
self.x: Union[Array, BlockArray] = x0
self.fixed_point_residual: float = snp.inf
self.v: Union[Array, BlockArray] = x0
self.t: float = 1.0
self.L: float = L0

super().__init__(**kwargs)

def step(self):
"""Take a single AcceleratedPAPGM step."""
        assert snp.allclose(snp.sum(snp.array(self.rho_list)), 1.0)
assert snp.all(snp.array([rho >= 0 for rho in self.rho_list]))

x_old = self.x
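        # Gradient step on f from the extrapolated point v, with step size 1/L.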
z = self.v - 1.0 / self.L * self.f.grad(self.v)

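        # Proximal averaging: the new iterate is the rho-weighted average of
        # the proximal operators of the g_i, all evaluated at the same point z.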
self.fixed_point_residual = 0
self.x = snp.zeros_like(z)
for gi, rhoi in zip(self.g_list, self.rho_list):
self.x += rhoi * gi.prox(z, 1.0 / self.L)
self.fixed_point_residual += snp.linalg.norm(self.x - self.v)

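        # Nesterov-style acceleration: update the momentum parameter t and
        # extrapolate beyond the new iterate.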
t_old = self.t
self.t = 0.5 * (1 + snp.sqrt(1 + 4 * t_old**2))
self.v = self.x + ((t_old - 1) / self.t) * (self.x - x_old)

    def _working_vars_finite(self) -> bool:
        """Determine whether ``NaN`` or ``Inf`` encountered in solve.

Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in
a solver working variable.
"""
return snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.v))

    def minimizer(self):
        """Return current estimate of the functional minimizer."""
return self.x

def objective(self, x: Optional[Union[Array, BlockArray]] = None) -> float:
r"""Evaluate the objective function

.. math::
            f(\mb{x}) + \sum_{i=1}^N \rho_i g_i(\mb{x}) \;.

Args:
x: Point at which to evaluate objective function. If ``None``,
the objective is evaluated at the current iterate
:code:`self.x`.

Returns:
Value of the objective function.
"""
if x is None:
x = self.x
out = 0.0
if self.f:
out += self.f(x)
for gi, rhoi in zip(self.g_list, self.rho_list):
out += rhoi * gi(x)
return out

def _objective_evaluatable(self):
"""Determine whether the objective function can be evaluated."""
        return (not self.f or self.f.has_eval) and all(g.has_eval for g in self.g_list)

def _itstat_extra_fields(self):
"""Define AcceleratedPAPGM iteration statistics fields."""
itstat_fields = {"L": "%9.3e", "Residual": "%9.3e"}
itstat_attrib = ["L", "norm_residual()"]
return itstat_fields, itstat_attrib

def norm_residual(self) -> float:
r"""Return the fixed point residual.

Return the fixed point residual (see Sec. 4.3 of
:cite:`liu-2018-first`).
"""
return self.fixed_point_residual
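
A minimal usage sketch of the new solver, based on the signature above (not
part of the diff; the identity forward operator, zero measurements, and
parameter values below are placeholder assumptions, and the import path
assumes the new class is exported from scico.optimize, which this diff does
not show; SquaredL2Loss, L1Norm, and NonNegativeIndicator are existing SCICO
classes):

import scico.numpy as snp
from scico import functional, linop, loss
from scico.optimize import AcceleratedPAPGM

# Placeholder problem setup: replace with a real forward operator and data.
A = linop.Identity((256, 256))
y = snp.zeros((256, 256))
f = loss.SquaredL2Loss(y=y, A=A)

# Two proximable priors, averaged with non-negative weights that sum to 1.
g_list = [functional.L1Norm(), functional.NonNegativeIndicator()]
rho_list = [0.5, 0.5]

solver = AcceleratedPAPGM(
    f=f,
    g_list=g_list,
    rho_list=rho_list,
    L0=1e2,  # initial Lipschitz estimate; problem dependent
    x0=snp.zeros(A.input_shape),
    maxiter=100,
)
x_hat = solver.solve()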