From 5ffd1f9046fa8e7e2e2f7d37846993a3ab2221dd Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 10 Nov 2023 15:12:52 -0700 Subject: [PATCH] Add proximal average implementation (#469) * Work in progress on proximal average * Resolve jit issues due to Python conditionals * Add __repr__ method * Fix Functional __repr__ method * Improve weight handling * Add API docs * Minor edits * Add test * Improve functional __repr__ formatting * Clean up bibtex * Fix dtype issues * Add example * Rename examples * Edit example docs * Update submodule * Edit example docs * Improve example script * Add option for excluding indicator function from sum of components in evaluation * Rename some example scripts * Update submodule * Add tests * Rename some example scripts --- data | 2 +- docs/source/examples.rst | 33 ++--- docs/source/references.bib | 26 ++-- examples/scripts/README.rst | 36 +++--- ...pp_bm3d_pgm.py => deconv_ppp_bm3d_apgm.py} | 0 .../{denoise_tv_pgm.py => denoise_tv_apgm.py} | 0 examples/scripts/index.rst | 33 ++--- .../{sparsecode_pgm.py => sparsecode_apgm.py} | 4 +- ...arsecode_admm.py => sparsecode_nn_admm.py} | 7 +- examples/scripts/sparsecode_nn_apgm.py | 99 +++++++++++++++ ...sson_pgm.py => sparsecode_poisson_apgm.py} | 0 scico/functional/__init__.py | 2 + scico/functional/_functional.py | 9 +- scico/functional/_proxavg.py | 113 ++++++++++++++++++ scico/functional/_tvnorm.py | 4 +- scico/test/functional/test_misc.py | 1 + scico/test/functional/test_proxavg.py | 61 ++++++++++ 17 files changed, 366 insertions(+), 64 deletions(-) rename examples/scripts/{deconv_ppp_bm3d_pgm.py => deconv_ppp_bm3d_apgm.py} (100%) rename examples/scripts/{denoise_tv_pgm.py => denoise_tv_apgm.py} (100%) rename examples/scripts/{sparsecode_pgm.py => sparsecode_apgm.py} (95%) rename examples/scripts/{sparsecode_admm.py => sparsecode_nn_admm.py} (91%) create mode 100644 examples/scripts/sparsecode_nn_apgm.py rename examples/scripts/{sparsecode_poisson_pgm.py => sparsecode_poisson_apgm.py} (100%) create mode 100644 scico/functional/_proxavg.py create mode 100644 scico/test/functional/test_proxavg.py diff --git a/data b/data index f99c8b524..84076ec7f 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit f99c8b524230289334d20fdc33fde07aa666b2e0 +Subproject commit 84076ec7f8cb743ea7081d091724aa15018b79a9 diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 87bb48263..f985b54e6 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -53,7 +53,7 @@ Deconvolution examples/deconv_microscopy_tv_admm examples/deconv_microscopy_allchn_tv_admm examples/deconv_ppp_bm3d_admm - examples/deconv_ppp_bm3d_pgm + examples/deconv_ppp_bm3d_apgm examples/deconv_ppp_dncnn_admm examples/deconv_ppp_dncnn_padmm examples/deconv_ppp_bm4d_admm @@ -67,11 +67,12 @@ Sparse Coding .. toctree:: :maxdepth: 1 - examples/sparsecode_admm + examples/sparsecode_nn_admm + examples/sparsecode_nn_apgm examples/sparsecode_conv_admm examples/sparsecode_conv_md_admm - examples/sparsecode_pgm - examples/sparsecode_poisson_pgm + examples/sparsecode_apgm + examples/sparsecode_poisson_apgm Miscellaneous @@ -84,7 +85,7 @@ Miscellaneous examples/superres_ppp_dncnn_admm examples/denoise_l1tv_admm examples/denoise_tv_admm - examples/denoise_tv_pgm + examples/denoise_tv_apgm examples/denoise_tv_multi examples/denoise_cplx_tv_nlpadmm examples/denoise_cplx_tv_pdhg @@ -113,7 +114,7 @@ Plug and Play Priors examples/ct_svmbir_ppp_bm3d_admm_prox examples/ct_fan_svmbir_ppp_bm3d_admm_prox examples/deconv_ppp_bm3d_admm - examples/deconv_ppp_bm3d_pgm + examples/deconv_ppp_bm3d_apgm examples/deconv_ppp_dncnn_admm examples/deconv_ppp_dncnn_padmm examples/deconv_ppp_bm4d_admm @@ -142,7 +143,7 @@ Total Variation examples/deconv_microscopy_allchn_tv_admm examples/denoise_l1tv_admm examples/denoise_tv_admm - examples/denoise_tv_pgm + examples/denoise_tv_apgm examples/denoise_tv_multi examples/denoise_cplx_tv_nlpadmm examples/denoise_cplx_tv_pdhg @@ -157,11 +158,12 @@ Sparsity :maxdepth: 1 examples/diffusercam_tv_admm - examples/sparsecode_admm + examples/sparsecode_nn_admm + examples/sparsecode_nn_apgm examples/sparsecode_conv_admm examples/sparsecode_conv_md_admm - examples/sparsecode_pgm - examples/sparsecode_poisson_pgm + examples/sparsecode_apgm + examples/sparsecode_poisson_apgm examples/video_rpca_admm @@ -215,7 +217,7 @@ ADMM examples/deconv_ppp_dncnn_admm examples/deconv_ppp_bm4d_admm examples/diffusercam_tv_admm - examples/sparsecode_admm + examples/sparsecode_nn_admm examples/sparsecode_conv_admm examples/sparsecode_conv_md_admm examples/demosaic_ppp_bm3d_admm @@ -274,10 +276,11 @@ PGM .. toctree:: :maxdepth: 1 - examples/deconv_ppp_bm3d_pgm - examples/sparsecode_pgm - examples/sparsecode_poisson_pgm - examples/denoise_tv_pgm + examples/deconv_ppp_bm3d_apgm + examples/sparsecode_apgm + examples/sparsecode_nn_apgm + examples/sparsecode_poisson_apgm + examples/denoise_tv_apgm PCG diff --git a/docs/source/references.bib b/docs/source/references.bib index b611601bb..bd3eb4705 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -110,7 +110,7 @@ @InCollection {beck-2010-gradient pages = {42--88}, publisher = {Cambridge University Press}, year = 2010, - doi = {10.1017/CBO9780511804458.003}, + doi = {10.1017/CBO9780511804458.003}, url = {http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf} } @@ -124,7 +124,7 @@ @Software {bradbury-2018-jax {P}ython+{N}um{P}y programs}, url = {http://github.com/google/jax}, version = {0.2.5}, - year = {2018} + year = 2018 } @Book {beck-2017-first, @@ -420,11 +420,11 @@ @Article {kamilov-2023-plugandplay Learned Models in Computational Imaging}, journal = {IEEE Signal Processing Magazine}, year = 2023, - month = Jan, - volume = 40, - number = 1, - pages = {85--97}, - doi = {10.1109/MSP.2022.3199595} + month = Jan, + volume = 40, + number = 1, + pages = {85--97}, + doi = {10.1109/MSP.2022.3199595} } @Article {liu-2018-first, @@ -722,6 +722,18 @@ @Article {yang-2012-linearized pages = {301--329} } +@InProceedings {yu-2013-better, + author = {Yu, Yao-Liang}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {C.J. Burges and L. Bottou and M. Welling and + Z. Ghahramani and K.Q. Weinberger}, + title = {Better Approximation and Faster Algorithm Using the + Proximal Average}, + url = {https://proceedings.neurips.cc/paper_files/paper/2013/file/49182f81e6a13cf5eaa496d51fea6406-Paper.pdf}, + volume = 26, + year = 2013 +} + @Article {zhang-2017-dncnn, author = {Kai Zhang and Wangmeng Zuo and Yunjin Chen and Deyu Meng and Lei Zhang}, diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index 50775448d..66d5d10be 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -61,7 +61,7 @@ Deconvolution Deconvolution Microscopy (All Channels) `deconv_ppp_bm3d_admm.py `_ PPP (with BM3D) Image Deconvolution (ADMM Solver) - `deconv_ppp_bm3d_pgm.py `_ + `deconv_ppp_bm3d_apgm.py `_ PPP (with BM3D) Image Deconvolution (APGM Solver) `deconv_ppp_dncnn_admm.py `_ PPP (with DnCNN) Image Deconvolution (ADMM Solver) @@ -78,15 +78,17 @@ Deconvolution Sparse Coding ^^^^^^^^^^^^^ - `sparsecode_admm.py `_ + `sparsecode_nn_admm.py `_ Non-Negative Basis Pursuit DeNoising (ADMM) + `sparsecode_nn_apgm.py `_ + Non-Negative Basis Pursuit DeNoising (APGM) `sparsecode_conv_admm.py `_ Convolutional Sparse Coding (ADMM) `sparsecode_conv_md_admm.py `_ Convolutional Sparse Coding with Mask Decoupling (ADMM) - `sparsecode_pgm.py `_ + `sparsecode_apgm.py `_ Basis Pursuit DeNoising (APGM) - `sparsecode_poisson_pgm.py `_ + `sparsecode_poisson_apgm.py `_ Non-negative Poisson Loss Reconstruction (APGM) @@ -101,7 +103,7 @@ Miscellaneous ℓ1 Total Variation Denoising `denoise_tv_admm.py `_ Total Variation Denoising (ADMM) - `denoise_tv_pgm.py `_ + `denoise_tv_apgm.py `_ Total Variation Denoising with Constraint (APGM) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising @@ -139,7 +141,7 @@ Plug and Play Priors PPP (with BM3D) Fan-Beam CT Reconstruction `deconv_ppp_bm3d_admm.py `_ PPP (with BM3D) Image Deconvolution (ADMM Solver) - `deconv_ppp_bm3d_pgm.py `_ + `deconv_ppp_bm3d_apgm.py `_ PPP (with BM3D) Image Deconvolution (APGM Solver) `deconv_ppp_dncnn_admm.py `_ PPP (with DnCNN) Image Deconvolution (ADMM Solver) @@ -186,7 +188,7 @@ Total Variation ℓ1 Total Variation Denoising `denoise_tv_admm.py `_ Total Variation Denoising (ADMM) - `denoise_tv_pgm.py `_ + `denoise_tv_apgm.py `_ Total Variation Denoising with Constraint (APGM) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising @@ -204,15 +206,17 @@ Sparsity `diffusercam_tv_admm.py `_ TV-Regularized 3D DiffuserCam Reconstruction - `sparsecode_admm.py `_ + `sparsecode_nn_admm.py `_ Non-Negative Basis Pursuit DeNoising (ADMM) + `sparsecode_nn_apgm.py `_ + Non-Negative Basis Pursuit DeNoising (APGM) `sparsecode_conv_admm.py `_ Convolutional Sparse Coding (ADMM) `sparsecode_conv_md_admm.py `_ Convolutional Sparse Coding with Mask Decoupling (ADMM) - `sparsecode_pgm.py `_ + `sparsecode_apgm.py `_ Basis Pursuit DeNoising (APGM) - `sparsecode_poisson_pgm.py `_ + `sparsecode_poisson_apgm.py `_ Non-negative Poisson Loss Reconstruction (APGM) `video_rpca_admm.py `_ Video Decomposition via Robust PCA @@ -289,7 +293,7 @@ ADMM PPP (with BM4D) Volume Deconvolution `diffusercam_tv_admm.py `_ TV-Regularized 3D DiffuserCam Reconstruction - `sparsecode_admm.py `_ + `sparsecode_nn_admm.py `_ Non-Negative Basis Pursuit DeNoising (ADMM) `sparsecode_conv_admm.py `_ Convolutional Sparse Coding (ADMM) @@ -352,13 +356,15 @@ PDHG PGM ^^^ - `deconv_ppp_bm3d_pgm.py `_ + `deconv_ppp_bm3d_apgm.py `_ PPP (with BM3D) Image Deconvolution (APGM Solver) - `sparsecode_pgm.py `_ + `sparsecode_apgm.py `_ Basis Pursuit DeNoising (APGM) - `sparsecode_poisson_pgm.py `_ + `sparsecode_nn_apgm.py `_ + Non-Negative Basis Pursuit DeNoising (APGM) + `sparsecode_poisson_apgm.py `_ Non-negative Poisson Loss Reconstruction (APGM) - `denoise_tv_pgm.py `_ + `denoise_tv_apgm.py `_ Total Variation Denoising with Constraint (APGM) diff --git a/examples/scripts/deconv_ppp_bm3d_pgm.py b/examples/scripts/deconv_ppp_bm3d_apgm.py similarity index 100% rename from examples/scripts/deconv_ppp_bm3d_pgm.py rename to examples/scripts/deconv_ppp_bm3d_apgm.py diff --git a/examples/scripts/denoise_tv_pgm.py b/examples/scripts/denoise_tv_apgm.py similarity index 100% rename from examples/scripts/denoise_tv_pgm.py rename to examples/scripts/denoise_tv_apgm.py diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index 136e610c6..982c68f7b 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -37,7 +37,7 @@ Deconvolution - deconv_microscopy_tv_admm.py - deconv_microscopy_allchn_tv_admm.py - deconv_ppp_bm3d_admm.py - - deconv_ppp_bm3d_pgm.py + - deconv_ppp_bm3d_apgm.py - deconv_ppp_dncnn_admm.py - deconv_ppp_dncnn_padmm.py - deconv_ppp_bm4d_admm.py @@ -48,11 +48,12 @@ Deconvolution Sparse Coding ^^^^^^^^^^^^^ - - sparsecode_admm.py + - sparsecode_nn_admm.py + - sparsecode_nn_apgm.py - sparsecode_conv_admm.py - sparsecode_conv_md_admm.py - - sparsecode_pgm.py - - sparsecode_poisson_pgm.py + - sparsecode_apgm.py + - sparsecode_poisson_apgm.py Miscellaneous @@ -62,7 +63,7 @@ Miscellaneous - superres_ppp_dncnn_admm.py - denoise_l1tv_admm.py - denoise_tv_admm.py - - denoise_tv_pgm.py + - denoise_tv_apgm.py - denoise_tv_multi.py - denoise_cplx_tv_nlpadmm.py - denoise_cplx_tv_pdhg.py @@ -85,7 +86,7 @@ Plug and Play Priors - ct_svmbir_ppp_bm3d_admm_prox.py - ct_fan_svmbir_ppp_bm3d_admm_prox.py - deconv_ppp_bm3d_admm.py - - deconv_ppp_bm3d_pgm.py + - deconv_ppp_bm3d_apgm.py - deconv_ppp_dncnn_admm.py - deconv_ppp_dncnn_padmm.py - deconv_ppp_bm4d_admm.py @@ -111,7 +112,7 @@ Total Variation - deconv_microscopy_allchn_tv_admm.py - denoise_l1tv_admm.py - denoise_tv_admm.py - - denoise_tv_pgm.py + - denoise_tv_apgm.py - denoise_tv_multi.py - denoise_cplx_tv_nlpadmm.py - denoise_cplx_tv_pdhg.py @@ -123,11 +124,12 @@ Sparsity ^^^^^^^^ - diffusercam_tv_admm.py - - sparsecode_admm.py + - sparsecode_nn_admm.py + - sparsecode_nn_apgm.py - sparsecode_conv_admm.py - sparsecode_conv_md_admm.py - - sparsecode_pgm.py - - sparsecode_poisson_pgm.py + - sparsecode_apgm.py + - sparsecode_poisson_apgm.py - video_rpca_admm.py @@ -172,7 +174,7 @@ ADMM - deconv_ppp_dncnn_admm.py - deconv_ppp_bm4d_admm.py - diffusercam_tv_admm.py - - sparsecode_admm.py + - sparsecode_nn_admm.py - sparsecode_conv_admm.py - sparsecode_conv_md_admm.py - demosaic_ppp_bm3d_admm.py @@ -216,10 +218,11 @@ PDHG PGM ^^^ - - deconv_ppp_bm3d_pgm.py - - sparsecode_pgm.py - - sparsecode_poisson_pgm.py - - denoise_tv_pgm.py + - deconv_ppp_bm3d_apgm.py + - sparsecode_apgm.py + - sparsecode_nn_apgm.py + - sparsecode_poisson_apgm.py + - denoise_tv_apgm.py PCG diff --git a/examples/scripts/sparsecode_pgm.py b/examples/scripts/sparsecode_apgm.py similarity index 95% rename from examples/scripts/sparsecode_pgm.py rename to examples/scripts/sparsecode_apgm.py index f5d34e08e..2d7ef7800 100644 --- a/examples/scripts/sparsecode_pgm.py +++ b/examples/scripts/sparsecode_apgm.py @@ -35,10 +35,10 @@ σ = 0.5 # Noise level np.random.seed(12345) -D = np.random.randn(m, n) +D = np.random.randn(m, n).astype(np.float32) L0 = np.linalg.norm(D, 2) ** 2 -x_gt = np.zeros(n) # true signal +x_gt = np.zeros(n, dtype=np.float32) # true signal idx = np.random.permutation(list(range(0, n - 1))) x_gt[idx[0:s]] = np.random.randn(s) y = D @ x_gt + σ * np.random.randn(m) # synthetic signal diff --git a/examples/scripts/sparsecode_admm.py b/examples/scripts/sparsecode_nn_admm.py similarity index 91% rename from examples/scripts/sparsecode_admm.py rename to examples/scripts/sparsecode_nn_admm.py index 829ccbff1..6d48f7b2c 100644 --- a/examples/scripts/sparsecode_admm.py +++ b/examples/scripts/sparsecode_nn_admm.py @@ -17,6 +17,9 @@ where $D$ the dictionary, $\mathbf{y}$ the signal to be represented, $\mathbf{x}$ is the sparse representation, and $I(\mathbf{x} \geq 0)$ is the non-negative indicator. + +In this example the problem is solved via ADMM, while Accelerated PGM is +used in a [companion example](sparsecode_nn_apgm.rst). """ import numpy as np @@ -36,10 +39,10 @@ s = 10 # sparsity level np.random.seed(1) -D = np.random.randn(m, n) +D = np.random.randn(m, n).astype(np.float32) D = D / np.linalg.norm(D, axis=0, keepdims=True) # normalize dictionary -xt = np.zeros(n) # true signal +xt = np.zeros(n, dtype=np.float32) # true signal idx = np.random.randint(low=0, high=n, size=s) # support of xt xt[idx] = np.random.rand(s) y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal diff --git a/examples/scripts/sparsecode_nn_apgm.py b/examples/scripts/sparsecode_nn_apgm.py new file mode 100644 index 000000000..d980ce7e2 --- /dev/null +++ b/examples/scripts/sparsecode_nn_apgm.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +Non-Negative Basis Pursuit DeNoising (APGM) +=========================================== + +This example demonstrates the solution of a non-negative sparse coding +problem + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - D \mathbf{x} \|_2^2 + + \lambda \| \mathbf{x} \|_1 + I(\mathbf{x} \geq 0) \;,$$ + +where $D$ the dictionary, $\mathbf{y}$ the signal to be represented, +$\mathbf{x}$ is the sparse representation, and $I(\mathbf{x} \geq 0)$ +is the non-negative indicator. + +In this example the problem is solved via Accelerated PGM, using the +proximal averaging method :cite:`yu-2013-better` to approximate the +proximal operator of the sum of the $\ell_1$ norm and an indicator +function, while ADMM is used in a +[companion example](sparsecode_nn_admm.rst). +""" + +import numpy as np + +import scico.numpy as snp +from scico import functional, linop, loss, plot +from scico.optimize.pgm import AcceleratedPGM +from scico.util import device_info + +""" +Create random dictionary, reference random sparse representation, and +test signal consisting of the synthesis of the reference sparse +representation. +""" +m = 32 # signal size +n = 128 # dictionary size +s = 10 # sparsity level + +np.random.seed(1) +D = np.random.randn(m, n).astype(np.float32) +D = D / np.linalg.norm(D, axis=0, keepdims=True) # normalize dictionary +L0 = max(np.linalg.norm(D, 2) ** 2, 5e1) + +xt = np.zeros(n, dtype=np.float32) # true signal +idx = np.random.randint(low=0, high=n, size=s) # support of xt +xt[idx] = np.random.rand(s) +y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal + +xt = snp.array(xt) # convert to jax array +y = snp.array(y) # convert to jax array + + +""" +Set up the forward operator and APGM solver object. +""" +lmbda = 2e-1 +A = linop.MatrixOperator(D) +f = loss.SquaredL2Loss(y=y, A=A) +g = functional.ProximalAverage([lmbda * functional.L1Norm(), functional.NonNegativeIndicator()]) +maxiter = 250 # number of APGM iterations +solver = AcceleratedPGM( + f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, itstat_options={"display": True, "period": 20} +) + + +""" +Run the solver. +""" +print(f"Solving on {device_info()}\n") +x = solver.solve() + + +""" +Plot the recovered coefficients and signal. +""" +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) +plot.plot( + np.vstack((xt, solver.x)).T, + title="Coefficients", + lgnd=("Ground Truth", "Recovered"), + fig=fig, + ax=ax[0], +) +plot.plot( + np.vstack((D @ xt, y, D @ solver.x)).T, + title="Signal", + lgnd=("Ground Truth", "Noisy", "Recovered"), + fig=fig, + ax=ax[1], +) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/sparsecode_poisson_pgm.py b/examples/scripts/sparsecode_poisson_apgm.py similarity index 100% rename from examples/scripts/sparsecode_poisson_pgm.py rename to examples/scripts/sparsecode_poisson_apgm.py diff --git a/scico/functional/__init__.py b/scico/functional/__init__.py index 48509cd40..e8996edec 100644 --- a/scico/functional/__init__.py +++ b/scico/functional/__init__.py @@ -22,6 +22,7 @@ L1MinusL2Norm, ) from ._tvnorm import AnisotropicTVNorm +from ._proxavg import ProximalAverage from ._indicator import NonNegativeIndicator, L2BallIndicator from ._denoiser import BM3D, BM4D, DnCNN from ._dist import SetDistance, SquaredSetDistance @@ -43,6 +44,7 @@ "NonNegativeIndicator", "NuclearNorm", "L2BallIndicator", + "ProximalAverage", "SetDistance", "SquaredSetDistance", "BM3D", diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index 94bfb6053..3216dac40 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -36,10 +36,7 @@ def __init__(self): self._grad = scico.grad(self.__call__) def __repr__(self): - return f"""{type(self)} -has_eval = {self.has_eval} -has_prox = {self.has_prox} - """ + return f"""{type(self)} (has_eval = {self.has_eval}, has_prox = {self.has_prox})""" def __mul__(self, other): if snp.isscalar(other) or isinstance(other, jax.core.Tracer): @@ -151,7 +148,9 @@ class ScaledFunctional(Functional): r"""A functional multiplied by a scalar.""" def __repr__(self): - return "Scaled functional of type " + str(type(self.functional)) + return ( + "Scaled functional of type " + str(type(self.functional)) + f" (scale = {self.scale})" + ) def __init__(self, functional: Functional, scale: float): self.functional = functional diff --git a/scico/functional/_proxavg.py b/scico/functional/_proxavg.py new file mode 100644 index 000000000..9a06f6f61 --- /dev/null +++ b/scico/functional/_proxavg.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2023 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Implementation of the proximal average method.""" + +from typing import List, Optional, Union + +from scico.numpy import Array, BlockArray, isinf + +from ._functional import Functional + + +class ProximalAverage(Functional): + """Weighted average of functionals. + + A functional that is composed of a weighted average of functionals. + All of the component functionals are required to have proximal + operators. The proximal operator of the composite functional is + approximated via the proximal average method :cite:`yu-2013-better`, + which holds for small scaling parameters. This does not imply that it + can only be applied to problems requiring a small regularization + parameter since most proximal algorithms include an additional + algorithm parameter that also plays a role in the parameter of the + proximal operator. For example, in :class:`.PGM` and + :class:`.AcceleratedPGM`, the scaled proximal operator parameter + is the regularization parameter divided by the `L0` algorithm + parameter, and for :class:`.ADMM`, the scaled proximal operator + parameters are the regularization parameters divided by the entries + in the `rho_list` algorithm parameter. + """ + + def __init__( + self, + func_list: List[Functional], + alpha_list: Optional[List[float]] = None, + no_inf_eval=True, + ): + """ + Args: + func_list: List of component :class:`.Functional` objects, + all of which must have a proximal operator. + alpha_list: List of scalar weights for each + :class:`.Functional`. If not specified, defaults to equal + weights. If specified, the list of weights must have the + same length as the :class:`.Functional` list. If the + weights do not sum to unity, they are scaled to ensure + that they do. + no_inf_eval: If ``True``, exclude infinite values (typically + associated with a functional that is an indicator + function) from the evaluation of the sum of component + functionals. + """ + self.has_prox = all([f.has_prox for f in func_list]) + if not self.has_prox: + raise ValueError("All functionals in func_list must have has_prox == True.") + self.has_eval = all([f.has_eval for f in func_list]) + self.no_inf_eval = no_inf_eval + self.func_list = func_list + N = len(func_list) + if alpha_list is None: + self.alpha_list = [1.0 / N] * N + else: + if len(alpha_list) != N: + raise ValueError("If specified, alpha_list must have the same length as func_list") + alpha_sum = sum(alpha_list) + if alpha_sum != 1.0: + alpha_list = [alpha / alpha_sum for alpha in alpha_list] + self.alpha_list = alpha_list + + def __repr__(self): + return ( + Functional.__repr__(self) + + "\n Weights: " + + ", ".join([str(alpha) for alpha in self.alpha_list]) + + "\n Components:\n" + + "\n".join([" " + repr(f) for f in self.func_list]) + ) + + def __call__(self, x: Union[Array, BlockArray]) -> float: + """Evaluate the weighted average of component functionals.""" + if self.has_eval: + weight_func_vals = [alpha * f(x) for (alpha, f) in zip(self.alpha_list, self.func_list)] + if self.no_inf_eval: + weight_func_vals = list(filter(lambda x: not isinf(x), weight_func_vals)) + return sum(weight_func_vals) + else: + raise ValueError("At least one functional in func_list has has_eval == False.") + + def prox( + self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs + ) -> Union[Array, BlockArray]: + r"""Approximate proximal operator of the average of functionals. + + Approximation of the proximal operator of a weighted average of + functionals computed via the proximal average method + :cite:`yu-2013-better`. + + Args: + v: Input array :math:`\mb{v}`. + lam: Proximal parameter :math:`\lam`. + kwargs: Additional arguments that may be used by derived + classes. + """ + return sum( + [ + alpha * f.prox(v, lam, **kwargs) + for (alpha, f) in zip(self.alpha_list, self.func_list) + ] + ) diff --git a/scico/functional/_tvnorm.py b/scico/functional/_tvnorm.py index b8d621c93..e3d2d067e 100644 --- a/scico/functional/_tvnorm.py +++ b/scico/functional/_tvnorm.py @@ -57,7 +57,7 @@ class AnisotropicTVNorm(Functional): has_prox = True def __init__(self, ndims: Optional[int] = None): - r""" + """ Args: ndims: Number of (trailing) dimensions of the input over which to apply the finite difference operator. If @@ -71,7 +71,7 @@ def __init__(self, ndims: Optional[int] = None): self.W: Optional[LinearOperator] = None def __call__(self, x: Array) -> float: - r"""Compute the anisotropic TV norm of an array.""" + """Compute the anisotropic TV norm of an array.""" if self.G is None or self.G.shape[1] != x.shape: if self.ndims is None: ndims = x.ndim diff --git a/scico/test/functional/test_misc.py b/scico/test/functional/test_misc.py index 1edcd0174..3c8ac98b4 100644 --- a/scico/test/functional/test_misc.py +++ b/scico/test/functional/test_misc.py @@ -17,6 +17,7 @@ class TestCheckAttrs: functional.Functional, functional.ScaledFunctional, functional.SeparableFunctional, + functional.ProximalAverage, ] to_check = [] for name, cls in functional.__dict__.items(): diff --git a/scico/test/functional/test_proxavg.py b/scico/test/functional/test_proxavg.py new file mode 100644 index 000000000..69f10ceba --- /dev/null +++ b/scico/test/functional/test_proxavg.py @@ -0,0 +1,61 @@ +import numpy as np + +import pytest + +import scico.numpy as snp +from scico import functional, linop, loss, metric +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.optimize.pgm import AcceleratedPGM + + +def test_proxavg_init(): + g0 = functional.L1Norm() + g1 = functional.L2Norm() + + with pytest.raises(ValueError): + h = functional.ProximalAverage( + [g0, g1], + alpha_list=[ + 0.1, + ], + ) + + h = functional.ProximalAverage([g0, g1], alpha_list=[0.1, 0.1]) + assert sum(h.alpha_list) == 1.0 + + g1.has_prox = False + with pytest.raises(ValueError): + h = functional.ProximalAverage([g0, g1]) + + +def test_proxavg(): + N = 128 + g = np.linspace(0, 2 * np.pi, N, dtype=np.float32) + y = np.sin(2 * g) + y[y > 0.5] = 0.5 + y[y < -0.5] = -0.5 + y *= 2 + y = snp.array(y) + + λ0 = 6e-1 + λ1 = 6e-1 + f = loss.SquaredL2Loss(y=y) + g0 = λ0 * functional.L1Norm() + g1 = λ1 * functional.L2Norm() + + solver = ADMM( + f=f, + g_list=[0.5 * g0, 0.5 * g1], + C_list=[linop.Identity(y.shape), linop.Identity(y.shape)], + rho_list=[1e1, 1e1], + x0=y, + maxiter=100, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-5, "maxiter": 20}), + ) + x_admm = solver.solve() + + h = functional.ProximalAverage([λ0 * functional.L1Norm(), λ1 * functional.L2Norm()]) + solver = AcceleratedPGM(f=f, g=h, L0=3.4e2, x0=y, maxiter=250) + x_prxavg = solver.solve() + + assert metric.snr(x_admm, x_prxavg) > 50