diff --git a/examples/scripts/ct_astra_tv_admm.py b/examples/scripts/ct_astra_tv_admm.py index a690df602..99542a9c0 100644 --- a/examples/scripts/ct_astra_tv_admm.py +++ b/examples/scripts/ct_astra_tv_admm.py @@ -74,7 +74,7 @@ x0=x0, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": num_inner_iter}), - verbose=True, + itstat_options={"display": True, "period": 5}, ) diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py index f86fceb30..c5e4f263e 100644 --- a/examples/scripts/ct_astra_weighted_tv_admm.py +++ b/examples/scripts/ct_astra_weighted_tv_admm.py @@ -118,7 +118,7 @@ def postprocess(x): x0=x0, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": max_inner_iter}), - verbose=True, + itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") admm_unweighted.solve() @@ -148,7 +148,7 @@ def postprocess(x): maxiter=maxiter, x0=x0, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": max_inner_iter}), - verbose=True, + itstat_options={"display": True, "period": 10}, ) admm_weighted.solve() x_weighted = postprocess(admm_weighted.x) diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py index 382855d3b..d06c3a65d 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py @@ -103,7 +103,7 @@ x0=x0, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}), - verbose=True, + itstat_options={"display": True, "period": 1}, ) diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py index 117d8dd59..98a896eb9 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py @@ -105,7 +105,7 @@ x0=x0, maxiter=20, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}), - verbose=True, + itstat_options={"display": True}, ) diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py index 5e79efa09..2ac51bd9b 100644 --- a/examples/scripts/ct_svmbir_tv_multi.py +++ b/examples/scripts/ct_svmbir_tv_multi.py @@ -104,12 +104,12 @@ x0=x0, maxiter=50, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 10}), - verbose=True, + itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") x_admm = solve_admm.solve() hist_admm = solve_admm.itstat_object.history(transpose=True) -print(metric.psnr(x_gt, x_admm)) +print(f"PSNR: {metric.psnr(x_gt, x_admm):.2f} dB\n") """ @@ -123,11 +123,11 @@ nu=2e-1, x0=x0, maxiter=50, - verbose=True, + itstat_options={"display": True, "period": 10}, ) x_ladmm = solver_ladmm.solve() hist_ladmm = solver_ladmm.itstat_object.history(transpose=True) -print(metric.psnr(x_gt, x_ladmm)) +print(f"PSNR: {metric.psnr(x_gt, x_ladmm):.2f} dB\n") """ @@ -141,11 +141,11 @@ sigma=8e0, x0=x0, maxiter=50, - verbose=True, + itstat_options={"display": True, "period": 10}, ) x_pdhg = solver_pdhg.solve() hist_pdhg = solver_pdhg.itstat_object.history(transpose=True) -print(metric.psnr(x_gt, x_pdhg)) +print(f"PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB\n") """ diff --git a/examples/scripts/deconv_circ_tv_admm.py b/examples/scripts/deconv_circ_tv_admm.py index 1214c716a..2b7f1b852 100644 --- a/examples/scripts/deconv_circ_tv_admm.py +++ b/examples/scripts/deconv_circ_tv_admm.py @@ -75,7 +75,7 @@ x0=A.adj(y), maxiter=maxiter,
subproblem_solver=CircularConvolveSolver(), - verbose=True, + itstat_options={"display": True, "period": 10}, ) diff --git a/examples/scripts/deconv_microscopy_allchn_tv_admm.py b/examples/scripts/deconv_microscopy_allchn_tv_admm.py index 0e3e22bc5..2fbd2ac01 100644 --- a/examples/scripts/deconv_microscopy_allchn_tv_admm.py +++ b/examples/scripts/deconv_microscopy_allchn_tv_admm.py @@ -188,16 +188,16 @@ def deconvolve_channel(channel): g2 = functional.NonNegativeIndicator() # non-negativity constraint if channel == 0: print("Displaying solver status for channel 0") - verbose = True + display = True else: - verbose = False + display = False solver = ADMM( f=None, g_list=[g0, g1, g2], C_list=[C0, C1, C2], rho_list=[ρ0, ρ1, ρ2], maxiter=maxiter, - verbose=verbose, + itstat_options={"display": display, "period": 10}, x0=y_pad, subproblem_solver=CircularConvolveSolver(), ) diff --git a/examples/scripts/deconv_microscopy_tv_admm.py b/examples/scripts/deconv_microscopy_tv_admm.py index 3e5d1c175..805d88e06 100644 --- a/examples/scripts/deconv_microscopy_tv_admm.py +++ b/examples/scripts/deconv_microscopy_tv_admm.py @@ -166,7 +166,7 @@ def block_avg(im, N): C_list=[C0, C1, C2], rho_list=[ρ0, ρ1, ρ2], maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, x0=y_pad, subproblem_solver=CircularConvolveSolver(), ) diff --git a/examples/scripts/deconv_ppp_bm3d_admm.py b/examples/scripts/deconv_ppp_bm3d_admm.py index 274d9523f..72c7605d2 100644 --- a/examples/scripts/deconv_ppp_bm3d_admm.py +++ b/examples/scripts/deconv_ppp_bm3d_admm.py @@ -69,7 +69,7 @@ x0=A.T @ y, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}), - verbose=True, + itstat_options={"display": True}, ) diff --git a/examples/scripts/deconv_ppp_bm3d_pgm.py b/examples/scripts/deconv_ppp_bm3d_pgm.py index 6b208b4fc..e601d0775 100644 --- a/examples/scripts/deconv_ppp_bm3d_pgm.py +++ b/examples/scripts/deconv_ppp_bm3d_pgm.py @@ -62,7 +62,9 @@ maxiter = 50 # number of APGM iterations -solver = AcceleratedPGM(f=f, g=g, L0=L0, x0=A.T @ y, maxiter=maxiter, verbose=True) +solver = AcceleratedPGM( + f=f, g=g, L0=L0, x0=A.T @ y, maxiter=maxiter, itstat_options={"display": True, "period": 10} +) print(f"Solving on {device_info()}\n") x = solver.solve() diff --git a/examples/scripts/deconv_ppp_dncnn_admm.py b/examples/scripts/deconv_ppp_dncnn_admm.py index d54d2bf27..2a688ce27 100644 --- a/examples/scripts/deconv_ppp_dncnn_admm.py +++ b/examples/scripts/deconv_ppp_dncnn_admm.py @@ -74,7 +74,7 @@ x0=A.T @ y, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 30}), - verbose=True, + itstat_options={"display": True}, ) diff --git a/examples/scripts/deconv_tv_admm.py b/examples/scripts/deconv_tv_admm.py index fc632fd7c..0f1a91641 100644 --- a/examples/scripts/deconv_tv_admm.py +++ b/examples/scripts/deconv_tv_admm.py @@ -74,7 +74,7 @@ x0=A.adj(y), maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(), - verbose=True, + itstat_options={"display": True, "period": 10}, ) diff --git a/examples/scripts/deconv_tv_admm_tune.py b/examples/scripts/deconv_tv_admm_tune.py index 2ea661123..54cad9dbb 100644 --- a/examples/scripts/deconv_tv_admm_tune.py +++ b/examples/scripts/deconv_tv_admm_tune.py @@ -64,7 +64,6 @@ def eval_params(config): x0=A.adj(y), maxiter=5, subproblem_solver=LinearSubproblemSolver(), - verbose=False, ) # Perform 50 iterations, reporting performance to ray.tune every 5 iterations.
for step in range(10): diff --git a/examples/scripts/demosaic_ppp_bm3d_admm.py b/examples/scripts/demosaic_ppp_bm3d_admm.py index 0f1c1acbd..29d72bc3f 100644 --- a/examples/scripts/demosaic_ppp_bm3d_admm.py +++ b/examples/scripts/demosaic_ppp_bm3d_admm.py @@ -115,7 +115,7 @@ def demosaic(cfaimg): x0=imgb, maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}), - verbose=True, + itstat_options={"display": True}, ) diff --git a/examples/scripts/denoise_tv_iso_admm.py b/examples/scripts/denoise_tv_iso_admm.py index 6e4fffa30..dab66ad19 100644 --- a/examples/scripts/denoise_tv_iso_admm.py +++ b/examples/scripts/denoise_tv_iso_admm.py @@ -70,13 +70,13 @@ x0=y, maxiter=100, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 20}), - verbose=True, + itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") solver.solve() x_iso = solver.x - +print() """ Denoise with anisotropic total variation for comparison. @@ -93,11 +93,12 @@ x0=y, maxiter=100, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 20}), - verbose=True, + itstat_options={"display": True, "period": 10}, ) solver.solve() x_aniso = solver.x +print() """ diff --git a/examples/scripts/denoise_tv_iso_multi.py b/examples/scripts/denoise_tv_iso_multi.py index 2e0b2f6b2..eed1ed3ab 100644 --- a/examples/scripts/denoise_tv_iso_multi.py +++ b/examples/scripts/denoise_tv_iso_multi.py @@ -71,7 +71,6 @@ x0=y, maxiter=1, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 1}), - verbose=False, ) solver_admm.solve() # trailing semi-colon suppresses output in notebook @@ -88,7 +87,7 @@ x0=y, maxiter=200, subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 2}), - verbose=True, + itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") solver_admm.solve() @@ -106,7 +105,7 @@ nu=1e-1, x0=y, maxiter=200, - verbose=True, + itstat_options={"display": True, "period": 10}, ) solver_ladmm.solve() hist_ladmm = solver_ladmm.itstat_object.history(transpose=True) @@ -122,7 +121,7 @@ tau=4e-1, sigma=4e-1, maxiter=200, - verbose=True, + itstat_options={"display": True, "period": 10}, ) solver_pdhg.solve() hist_pdhg = solver_pdhg.itstat_object.history(transpose=True) diff --git a/examples/scripts/denoise_tv_iso_pgm.py b/examples/scripts/denoise_tv_iso_pgm.py index 86fa48047..d3f955905 100644 --- a/examples/scripts/denoise_tv_iso_pgm.py +++ b/examples/scripts/denoise_tv_iso_pgm.py @@ -144,7 +144,7 @@ def prox(self, x: JaxArray, lam: float, **kwargs) -> JaxArray: L0=16.0 * f_iso.lmbda ** 2, x0=x0, maxiter=100, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=RobustLineSearchStepSize(), ) @@ -194,7 +194,7 @@ def prox(self, x: JaxArray, lam: float, **kwargs) -> JaxArray: L0=16.0 * f.lmbda ** 2, x0=x0, maxiter=100, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=RobustLineSearchStepSize(), ) diff --git a/examples/scripts/sparsecode_admm.py b/examples/scripts/sparsecode_admm.py index 7965672c6..6ca0bc4b3 100644 --- a/examples/scripts/sparsecode_admm.py +++ b/examples/scripts/sparsecode_admm.py @@ -69,7 +69,7 @@ x0=A.adj(y), maxiter=maxiter, subproblem_solver=LinearSubproblemSolver(), - verbose=True, + itstat_options={"display": True, "period": 10}, ) diff --git a/examples/scripts/sparsecode_pgm.py b/examples/scripts/sparsecode_pgm.py index 1fd0eb5a9..7bfa62593 100644 --- a/examples/scripts/sparsecode_pgm.py +++ b/examples/scripts/sparsecode_pgm.py @@ -58,7 +58,9 @@ A = linop.MatrixOperator(D) f =
loss.SquaredL2Loss(y=y, A=A) g = λ * functional.L1Norm() -solver = AcceleratedPGM(f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, verbose=True) +solver = AcceleratedPGM( + f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, itstat_options={"display": True, "period": 10} +) """ diff --git a/examples/scripts/sparsecode_poisson_blkarr_pgm.py b/examples/scripts/sparsecode_poisson_blkarr_pgm.py index 432755971..5ab3f6c14 100644 --- a/examples/scripts/sparsecode_poisson_blkarr_pgm.py +++ b/examples/scripts/sparsecode_poisson_blkarr_pgm.py @@ -188,7 +188,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, ) str_ss = type(solver.step_size).__name__ @@ -216,7 +216,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=BBStepSize(), ) str_ss = type(solver.step_size).__name__ @@ -244,7 +244,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=AdaptiveBBStepSize(kappa=0.75), ) str_ss = type(solver.step_size).__name__ @@ -272,7 +272,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=LineSearchStepSize(), ) str_ss = type(solver.step_size).__name__ @@ -300,7 +300,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=RobustLineSearchStepSize(), ) str_ss = type(solver.step_size).__name__ diff --git a/examples/scripts/sparsecode_poisson_pgm.py b/examples/scripts/sparsecode_poisson_pgm.py index 2f841113c..641eef581 100644 --- a/examples/scripts/sparsecode_poisson_pgm.py +++ b/examples/scripts/sparsecode_poisson_pgm.py @@ -159,7 +159,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, ) str_ss = type(solver.step_size).__name__ @@ -187,7 +187,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=BBStepSize(), ) str_ss = type(solver.step_size).__name__ @@ -215,7 +215,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=AdaptiveBBStepSize(kappa=0.75), ) str_ss = type(solver.step_size).__name__ @@ -243,7 +243,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=LineSearchStepSize(gamma_u=1.01), ) str_ss = type(solver.step_size).__name__ @@ -271,7 +271,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat): L0=L0, x0=x0, maxiter=maxiter, - verbose=True, + itstat_options={"display": True, "period": 10}, step_size=RobustLineSearchStepSize(), ) str_ss = type(solver.step_size).__name__ diff --git a/scico/admm.py b/scico/admm.py index 202be7d43..60cde9c3a 100644 --- a/scico/admm.py +++ b/scico/admm.py @@ -12,7 +12,7 @@ from __future__ import annotations from functools import reduce -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import jax from jax.scipy.sparse.linalg import cg as jax_cg @@ -397,8 +397,7 @@ def
__init__( x0: Optional[Union[JaxArray, BlockArray]] = None, maxiter: int = 100, subproblem_solver: Optional[SubproblemSolver] = None, - verbose: bool = False, - itstat: Optional[Tuple[dict, Callable]] = None, + itstat_options: Optional[dict] = None, ): r"""Initialize an :class:`ADMM` object. @@ -416,16 +415,16 @@ def __init__( subproblem_solver: Solver for :math:`\mb{x}`-update step. Defaults to ``None``, which implies use of an instance of :class:`GenericSubproblemSolver`. - verbose: Flag indicating whether iteration statistics should - be displayed. - itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` - is a dict suitable for passing to the `fields` argument - of the :class:`.diagnostics.IterationStats` initializer, - and `insertfunc` is a function with two parameters, an - integer and an ADMM object, responsible for constructing - a tuple ready for insertion into the - :class:`.diagnostics.IterationStats` object. If None, - default values are used for the tuple components. + itstat_options: A dict of named parameters to be passed to + the :class:`.diagnostics.IterationStats` initializer. The + dict may also include an additional key "itstat_func" + with the corresponding value being a function with two + parameters, an integer and an ADMM object, responsible + for constructing a tuple ready for insertion into the + :class:`.diagnostics.IterationStats` object. If ``None``, + default values are used for the dict entries, otherwise + the default dict is updated with the dict specified by + this parameter. """ N = len(g_list) if len(C_list) != N: @@ -445,48 +444,52 @@ def __init__( subproblem_solver = GenericSubproblemSolver() self.subproblem_solver: SubproblemSolver = subproblem_solver self.subproblem_solver.internal_init(self) - self.verbose: bool = verbose - if itstat: - itstat_dict = itstat[0] - itstat_func = itstat[1] - elif itstat is None: - if all([_.has_eval for _ in self.g_list]): - itstat_dict = { - "Iter": "%d", - "Time": "%8.2e", - "Objective": "%8.3e", - "Primal Rsdl": "%8.3e", - "Dual Rsdl": "%8.3e", - } - - def itstat_func(obj): - return ( - obj.itnum, - obj.timer.elapsed(), - obj.objective(), - obj.norm_primal_residual(), - obj.norm_dual_residual(), - ) - else: - # At least one 'g' can't be evaluated, so drop objective from the default itstat - itstat_dict = { - "Iter": "%d", - "Time": "%8.1e", - "Primal Rsdl": "%8.3e", - "Dual Rsdl": "%8.3e", - } - - def itstat_func(obj): - return ( - obj.itnum, - obj.timer.elapsed(), - obj.norm_primal_residual(), - obj.norm_dual_residual(), - ) - - self.itstat_object = IterationStats(itstat_dict, display=verbose) - self.itstat_insert_func = itstat_func + if all([_.has_eval for _ in self.g_list]): + # All 'g' functions can be evaluated, so objective function can be evaluated + itstat_fields = { + "Iter": "%d", + "Time": "%8.2e", + "Objective": "%8.3e", + "Primal Rsdl": "%8.3e", + "Dual Rsdl": "%8.3e", + } + + def itstat_func(obj): + return ( + obj.itnum, + obj.timer.elapsed(), + obj.objective(), + obj.norm_primal_residual(), + obj.norm_dual_residual(), + ) + + else: + # At least one 'g' can't be evaluated, so drop objective from the default itstat + itstat_fields = { + "Iter": "%d", + "Time": "%8.1e", + "Primal Rsdl": "%8.3e", + "Dual Rsdl": "%8.3e", + } + + def itstat_func(obj): + return ( + obj.itnum, + obj.timer.elapsed(), + obj.norm_primal_residual(), + obj.norm_dual_residual(), + ) + + default_itstat_options = { + "fields": itstat_fields, + "itstat_func": itstat_func, + "display": False, + } + if itstat_options: + 
default_itstat_options.update(itstat_options) + self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) + self.itstat_object = IterationStats(**default_itstat_options) if x0 is None: input_shape = C_list[0].input_shape @@ -678,4 +681,5 @@ def solve( self.timer.start() self.timer.stop() self.itnum += 1 + self.itstat_object.end() return self.x diff --git a/scico/diagnostics.py b/scico/diagnostics.py index 802269596..735e443e7 100644 --- a/scico/diagnostics.py +++ b/scico/diagnostics.py @@ -26,6 +26,8 @@ def __init__( fields: OrderedDict, ident: Optional[dict] = None, display: bool = False, + period: int = 1, + overwrite: bool = True, colsep: int = 2, ): """ @@ -51,6 +53,10 @@ def __init__( namedtuple used to record results. Defaults to ``None``. display: Flag indicating whether results should be printed to stdout. Defaults to ``False``. + period: Only display one result in every cycle of length + `period`. + overwrite: If ``True``, display all results, but each one + is overwritten by the next, except for one result per cycle. colsep: Number of spaces seperating fields in displayed tables. Defaults to 2. @@ -62,6 +68,10 @@ def __init__( # that field order is retained if not isinstance(fields, dict): raise TypeError("Parameter fields must be an instance of dict") + # Subsampling rate of results that are to be displayed + self.period = period + # Flag indicating whether to display all results (overwriting intermediate ones) or only one result per cycle + self.overwrite = overwrite # Number of spaces seperating fields in displayed tables self.colsep = colsep # Main list of inserted values @@ -134,8 +144,7 @@ ) def insert(self, values: Union[List, Tuple]): - """ - Insert a list of values for a single iteration. + """Insert a list of values for a single iteration. Args: values: Statistics for a single iteration. @@ -147,11 +156,29 @@ if self.disphdr is not None: print(self.disphdr) self.disphdr = None - print((" " * self.colsep).join(self.fieldformat) % values) + if self.overwrite: + if (len(self.iterations) - 1) % self.period == 0: + end = "\n" + else: + end = "\r" + print((" " * self.colsep).join(self.fieldformat) % values, end=end) + else: + if (len(self.iterations) - 1) % self.period == 0: + print((" " * self.colsep).join(self.fieldformat) % values) - def history(self, transpose: bool = False): + def end(self): + """Mark end of iterations. + + This method should be called at the end of a set of iterations. + Its only function is to ensure that the displayed output is left + in an appropriate state when overwriting is active with a display + period other than unity. """ - Retrieve record of all inserted iterations. + if self.overwrite and self.period > 1 and (len(self.iterations) - 1) % self.period: + print() + + def history(self, transpose: bool = False): + """Retrieve record of all inserted iterations.
Args: transpose: Flag indicating whether results should be returned diff --git a/scico/ladmm.py b/scico/ladmm.py index 03b3575bd..bb1940c01 100644 --- a/scico/ladmm.py +++ b/scico/ladmm.py @@ -11,7 +11,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import scico.numpy as snp from scico.blockarray import BlockArray @@ -98,8 +98,7 @@ def __init__( nu: float, x0: Optional[Union[JaxArray, BlockArray]] = None, maxiter: int = 100, - verbose: bool = False, - itstat: Optional[Tuple[dict, Callable]] = None, + itstat_options: Optional[dict] = None, ): r"""Initialize a :class:`LinearizedADMM` object. Args: f: Functional :math:`f` (usually a loss function). g: Functional :math:`g`. C: Operator :math:`C`. mu: First algorithm parameter. nu: Second algorithm parameter. x0: Starting point for :math:`\mb{x}`. If None, defaults to an array of zeros. maxiter: Number of ADMM outer-loop iterations. Default: 100. - verbose: Flag indicating whether iteration statistics should - be displayed. - itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` - is a dict suitable for passing to the `fields` argument - of the :class:`.diagnostics.IterationStats` initializer, - and `insertfunc` is a function with two parameters, an - integer and a LinearizedADMM object, responsible for - constructing a tuple ready for insertion into the - :class:`.diagnostics.IterationStats` object. If None, - default values are used for the tuple components. + itstat_options: A dict of named parameters to be passed to + the :class:`.diagnostics.IterationStats` initializer. The + dict may also include an additional key "itstat_func" + with the corresponding value being a function with two + parameters, an integer and a LinearizedADMM object, responsible + for constructing a tuple ready for insertion into the + :class:`.diagnostics.IterationStats` object. If ``None``, + default values are used for the dict entries, otherwise + the default dict is updated with the dict specified by + this parameter.
""" self.f: Functional = f self.g: Functional = g @@ -131,49 +130,52 @@ def __init__( self.itnum: int = 0 self.maxiter: int = maxiter self.timer: Timer = Timer() - self.verbose: bool = verbose - - if itstat: - itstat_dict = itstat[0] - itstat_func = itstat[1] - elif itstat is None: - if g.has_eval: - itstat_dict = { - "Iter": "%d", - "Time": "%8.2e", - "Objective": "%8.3e", - "Primal Rsdl": "%8.3e", - "Dual Rsdl": "%8.3e", - } - - def itstat_func(obj): - return ( - obj.itnum, - obj.timer.elapsed(), - obj.objective(), - obj.norm_primal_residual(), - obj.norm_dual_residual(), - ) - - else: - # At least one 'g' can't be evaluated, so drop objective from the default itstat - itstat_dict = { - "Iter": "%d", - "Time": "%8.1e", - "Primal Rsdl": "%8.3e", - "Dual Rsdl": "%8.3e", - } - - def itstat_func(obj): - return ( - obj.i, - obj.timer.elapsed(), - obj.norm_primal_residual(), - obj.norm_dual_residual(), - ) - - self.itstat_object = IterationStats(itstat_dict, display=verbose) - self.itstat_insert_func = itstat_func + + if g.has_eval: + # The 'g' functions can be evaluated, so objective function can be evaluated + itstat_fields = { + "Iter": "%d", + "Time": "%8.2e", + "Objective": "%8.3e", + "Primal Rsdl": "%8.3e", + "Dual Rsdl": "%8.3e", + } + + def itstat_func(obj): + return ( + obj.itnum, + obj.timer.elapsed(), + obj.objective(), + obj.norm_primal_residual(), + obj.norm_dual_residual(), + ) + + else: + # The 'g' function can't be evaluated, so drop objective from the default itstat + itstat_fields = { + "Iter": "%d", + "Time": "%8.1e", + "Primal Rsdl": "%8.3e", + "Dual Rsdl": "%8.3e", + } + + def itstat_func(obj): + return ( + obj.itnum, + obj.timer.elapsed(), + obj.norm_primal_residual(), + obj.norm_dual_residual(), + ) + + default_itstat_options = { + "fields": itstat_fields, + "itstat_func": itstat_func, + "display": False, + } + if itstat_options: + default_itstat_options.update(itstat_options) + self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) + self.itstat_object = IterationStats(**default_itstat_options) if x0 is None: input_shape = C.input_shape @@ -343,4 +345,5 @@ def solve( self.timer.start() self.timer.stop() self.itnum += 1 + self.itstat_object.end() return self.x diff --git a/scico/pgm.py b/scico/pgm.py index d39299679..291981b3d 100644 --- a/scico/pgm.py +++ b/scico/pgm.py @@ -11,7 +11,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union import jax @@ -390,8 +390,7 @@ def __init__( x0: Union[JaxArray, BlockArray], step_size: Optional[PGMStepSize] = None, maxiter: int = 100, - verbose: bool = False, - itstat: Optional[Tuple[dict, Callable]] = None, + itstat_options: Optional[dict] = None, ): r""" @@ -406,15 +405,16 @@ def __init__( Default: 100. verbose: Flag indicating whether iteration statistics should be displayed. - itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` - is a dict suitable - for passing to the `fields` argument of the - :class:`.diagnostics.IterationStats` initializer, and - `insertfunc` is a function with two parameters, an - integer and a PGM object, responsible for constructing a - tuple ready for insertion into the - :class:`.diagnostics.IterationStats` object. If None, - default values are used for the tuple components. + itstat_options: A dict of named parameters to be passed to + the :class:`.diagnostics.IterationStats` initializer. 
The + dict may also include an additional key "itstat_func" + with the corresponding value being a function with two + parameters, an integer and a PGM object, responsible + for constructing a tuple ready for insertion into the + :class:`.diagnostics.IterationStats` object. If ``None``, + default values are used for the dict entries, otherwise + the default dict is updated with the dict specified by + this parameter. """ if f.is_smooth is not True: @@ -444,12 +444,8 @@ def x_step(v, L): self.x_step = jax.jit(x_step) - self.verbose = verbose - if itstat: - itstat_dict = itstat[0] - itstat_func = itstat[1] - elif g.has_eval: - itstat_dict = { + if g.has_eval: + itstat_fields = { "Iter": "%d", "Time": "%8.2e", "Objective": "%8.3e", @@ -464,11 +460,19 @@ def x_step(v, L): pgm.norm_residual(), ) else: - itstat_dict = {"Iter": "%d", "Time": "%8.2e", "Residual": "%8.3e"} + itstat_fields = {"Iter": "%d", "Time": "%8.2e", "Residual": "%8.3e"} itstat_func = lambda pgm: (pgm.itnum, pgm.timer.elapsed(), pgm.norm_residual()) - self.itstat_object = IterationStats(itstat_dict, display=verbose) - self.itstat_insert_func = itstat_func + default_itstat_options = { + "fields": itstat_fields, + "itstat_func": itstat_func, + "display": False, + } + if itstat_options: + default_itstat_options.update(itstat_options) + self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) + self.itstat_object = IterationStats(**default_itstat_options) + self.x: Union[JaxArray, BlockArray] = ensure_on_device(x0) # current estimate of solution def objective(self, x) -> float: @@ -532,6 +536,7 @@ def solve( self.timer.start() self.timer.stop() self.itnum += 1 + self.itstat_object.end() return self.x @@ -555,8 +560,7 @@ def __init__( x0: Union[JaxArray, BlockArray], step_size: Optional[PGMStepSize] = None, maxiter: int = 100, - verbose: bool = False, - itstat: Optional[Union[tuple, list]] = None, + itstat_options: Optional[dict] = None, ): r""" @@ -571,14 +575,16 @@ def __init__( Default: 100. verbose: Flag indicating whether iteration statistics should be displayed. - itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` - is a dict suitable for passing to the `fields` argument - of the :class:`.diagnostics.IterationStats` initializer, - and `insertfunc` is a function with two parameters, an - integer and a PGM object, responsible for constructing a - tuple ready for insertion into the - :class:`.diagnostics.IterationStats` object. If None, - default values are used for the tuple components. + itstat_options: A dict of named parameters to be passed to + the :class:`.diagnostics.IterationStats` initializer. The + dict may also include an additional key "itstat_func" + with the corresponding value being a function with two + parameters, an integer and a PGM object, responsible + for constructing a tuple ready for insertion into the + :class:`.diagnostics.IterationStats` object. If ``None``, + default values are used for the dict entries, otherwise + the default dict is updated with the dict specified by + this parameter. 
""" x0 = ensure_on_device(x0) super().__init__( @@ -588,8 +594,7 @@ def __init__( x0=x0, step_size=step_size, maxiter=maxiter, - verbose=verbose, - itstat=itstat, + itstat_options=itstat_options, ) self.v = x0 diff --git a/scico/primaldual.py b/scico/primaldual.py index 39874e8d5..906901ba1 100644 --- a/scico/primaldual.py +++ b/scico/primaldual.py @@ -11,7 +11,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union import scico.numpy as snp from scico.blockarray import BlockArray @@ -98,8 +98,7 @@ def __init__( x0: Optional[Union[JaxArray, BlockArray]] = None, z0: Optional[Union[JaxArray, BlockArray]] = None, maxiter: int = 100, - verbose: bool = False, - itstat: Optional[Tuple[dict, Callable]] = None, + itstat_options: Optional[dict] = None, ): r"""Initialize a :class:`PDHG` object. @@ -115,16 +114,16 @@ def __init__( z0: Starting point for :math:`\mb{z}`. If None, defaults to an array of zeros. maxiter: Number of ADMM outer-loop iterations. Default: 100. - verbose: Flag indicating whether iteration statistics should - be displayed. - itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` - is a dict suitable for passing to the `fields` argument - of the :class:`.diagnostics.IterationStats` initializer, - and `insertfunc` is a function with two parameters, an - integer and a PDHG object, responsible for constructing - a tuple ready for insertion into the - :class:`.diagnostics.IterationStats` object. If None, - default values are used for the tuple components. + itstat_options: A dict of named parameters to be passed to + the :class:`.diagnostics.IterationStats` initializer. The + dict may also include an additional key "itstat_func" + with the corresponding value being a function with two + parameters, an integer and an ADMM object, responsible + for constructing a tuple ready for insertion into the + :class:`.diagnostics.IterationStats` object. If ``None``, + default values are used for the dict entries, otherwise + the default dict is updated with the dict specified by + this parameter. 
""" self.f: Functional = f self.g: Functional = g @@ -135,48 +134,52 @@ def __init__( self.itnum: int = 0 self.maxiter: int = maxiter self.timer: Timer = Timer() - self.verbose: bool = verbose - - if itstat: - itstat_dict = itstat[0] - itstat_func = itstat[1] - elif itstat is None: - if g.has_eval: - itstat_dict = { - "Iter": "%d", - "Time": "%8.2e", - "Objective": "%8.3e", - "Primal Rsdl": "%8.3e", - "Dual Rsdl": "%8.3e", - } - - def itstat_func(obj): - return ( - obj.itnum, - obj.timer.elapsed(), - obj.objective(), - obj.norm_primal_residual(), - obj.norm_dual_residual(), - ) - - else: - itstat_dict = { - "Iter": "%d", - "Time": "%8.1e", - "Primal Rsdl": "%8.3e", - "Dual Rsdl": "%8.3e", - } - - def itstat_func(obj): - return ( - obj.i, - obj.timer.elapsed(), - obj.norm_primal_residual(), - obj.norm_dual_residual(), - ) - - self.itstat_object = IterationStats(itstat_dict, display=verbose) - self.itstat_insert_func = itstat_func + + if g.has_eval: + # The 'g' functions can be evaluated, so objective function can be evaluated + itstat_fields = { + "Iter": "%d", + "Time": "%8.2e", + "Objective": "%8.3e", + "Primal Rsdl": "%8.3e", + "Dual Rsdl": "%8.3e", + } + + def itstat_func(obj): + return ( + obj.itnum, + obj.timer.elapsed(), + obj.objective(), + obj.norm_primal_residual(), + obj.norm_dual_residual(), + ) + + else: + # The 'g' function can't be evaluated, so drop objective from the default itstat + itstat_fields = { + "Iter": "%d", + "Time": "%8.1e", + "Primal Rsdl": "%8.3e", + "Dual Rsdl": "%8.3e", + } + + def itstat_func(obj): + return ( + obj.itnum, + obj.timer.elapsed(), + obj.norm_primal_residual(), + obj.norm_dual_residual(), + ) + + default_itstat_options = { + "fields": itstat_fields, + "itstat_func": itstat_func, + "display": False, + } + if itstat_options: + default_itstat_options.update(itstat_options) + self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) + self.itstat_object = IterationStats(**default_itstat_options) if x0 is None: input_shape = C.input_shape @@ -281,4 +284,5 @@ def solve( self.timer.start() self.timer.stop() self.itnum += 1 + self.itstat_object.end() return self.x diff --git a/scico/test/test_admm.py b/scico/test/test_admm.py index 0176d338e..344580cd8 100644 --- a/scico/test/test_admm.py +++ b/scico/test/test_admm.py @@ -26,7 +26,7 @@ def test_admm(self): g = (self.λ / 2) * functional.BM3D() C = linop.Identity(self.y.shape) - itstat_dict = {"Iter": "%d", "Time": "%8.2e"} + itstat_fields = {"Iter": "%d", "Time": "%8.2e"} def itstat_func(obj): return (obj.itnum, obj.timer.elapsed()) @@ -37,7 +37,7 @@ def itstat_func(obj): C_list=[C], rho_list=[ρ], maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, ) assert len(admm_.itstat_object.fieldname) == 4 assert snp.sum(admm_.x) == 0.0 @@ -47,8 +47,7 @@ def itstat_func(obj): C_list=[C], rho_list=[ρ], maxiter=maxiter, - verbose=False, - itstat=(itstat_dict, itstat_func), + itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False}, ) assert len(admm_.itstat_object.fieldname) == 2 @@ -95,7 +94,7 @@ def test_admm_generic(self): C_list=C_list, rho_list=rho_list, maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, x0=A.adj(self.y), subproblem_solver=GenericSubproblemSolver( minimize_kwargs={"options": {"maxiter": 100}} @@ -118,7 +117,7 @@ def test_admm_quadratic_scico(self): C_list=C_list, rho_list=rho_list, maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, x0=A.adj(self.y), 
subproblem_solver=LinearSubproblemSolver(cg_function="scico"), ) @@ -139,7 +138,7 @@ def test_admm_quadratic_jax(self): C_list=C_list, rho_list=rho_list, maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, x0=A.adj(self.y), subproblem_solver=LinearSubproblemSolver(cg_function="jax"), ) @@ -161,7 +160,7 @@ def test_admm_quadratic_relax(self): rho_list=rho_list, alpha=1.6, maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, x0=A.adj(self.y), subproblem_solver=LinearSubproblemSolver(cg_function="jax"), ) @@ -209,7 +208,7 @@ def test_admm_quadratic(self): C_list=C_list, rho_list=rho_list, maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, x0=A.adj(self.y), subproblem_solver=LinearSubproblemSolver(cg_function="scico"), ) @@ -251,7 +250,7 @@ def test_admm_generic(self): C_list=C_list, rho_list=rho_list, maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, x0=A.adj(self.y), subproblem_solver=GenericSubproblemSolver( minimize_kwargs={"options": {"maxiter": 100}} @@ -274,7 +273,7 @@ def test_admm_quadratic(self): C_list=C_list, rho_list=rho_list, maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, x0=A.adj(self.y), subproblem_solver=LinearSubproblemSolver(), ) @@ -310,7 +309,7 @@ def test_admm(self): C_list=self.C_list, rho_list=rho_list, maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, x0=self.A.adj(self.y), subproblem_solver=LinearSubproblemSolver(), ) @@ -321,7 +320,7 @@ def test_admm(self): C_list=self.C_list, rho_list=rho_list, maxiter=maxiter, - verbose=False, + itstat_options={"display": False}, x0=self.A.adj(self.y), subproblem_solver=CircularConvolveSolver(), ) diff --git a/scico/test/test_diagnostics.py b/scico/test/test_diagnostics.py index edcfd163f..b837b83fb 100644 --- a/scico/test/test_diagnostics.py +++ b/scico/test/test_diagnostics.py @@ -12,3 +12,15 @@ def test_itstat(self): assert its.history()[1].Iter == 1 assert its.history()[1].Objective == 1e2 assert its.history(transpose=True).Objective == [1.5, 100.0] + + def test_display(self, capsys): + its = diagnostics.IterationStats({"Iter": "%d"}, display=True, period=2, overwrite=False) + its.insert((0,)) + cap = capsys.readouterr() + assert cap.out == "Iter\n----\n 0\n" + its.insert((1,)) + cap = capsys.readouterr() + assert cap.out == "" + its.insert((2,)) + cap = capsys.readouterr() + assert cap.out == " 2\n" diff --git a/scico/test/test_ladmm.py b/scico/test/test_ladmm.py index bbca2b660..52f59d474 100644 --- a/scico/test/test_ladmm.py +++ b/scico/test/test_ladmm.py @@ -24,7 +24,7 @@ def test_ladmm(self): g = (self.λ / 2) * functional.BM3D() C = linop.Identity(self.y.shape) - itstat_dict = {"Iter": "%d", "Time": "%8.2e"} + itstat_fields = {"Iter": "%d", "Time": "%8.2e"} def itstat_func(obj): return (obj.itnum, obj.timer.elapsed()) @@ -36,7 +36,6 @@ def itstat_func(obj): mu=μ, nu=ν, maxiter=maxiter, - verbose=False, ) assert len(ladmm_.itstat_object.fieldname) == 4 assert snp.sum(ladmm_.x) == 0.0 @@ -47,8 +46,7 @@ def itstat_func(obj): mu=μ, nu=ν, maxiter=maxiter, - verbose=False, - itstat=(itstat_dict, itstat_func), + itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False}, ) assert len(ladmm_.itstat_object.fieldname) == 2 @@ -93,7 +91,6 @@ def test_ladmm(self): mu=μ, nu=ν, maxiter=maxiter, - verbose=False, x0=A.adj(self.y), ) x = ladmm_.solve() @@ -133,7 +130,6 @@ def test_ladmm(self): mu=μ, nu=ν, maxiter=maxiter, - verbose=False, x0=A.adj(self.y), ) x = ladmm_.solve() diff --git 
a/scico/test/test_pdhg.py b/scico/test/test_pdhg.py index 37f6f4077..25bc4dad1 100644 --- a/scico/test/test_pdhg.py +++ b/scico/test/test_pdhg.py @@ -24,7 +24,7 @@ def test_pdhg(self): g = (self.λ / 2) * functional.BM3D() C = linop.Identity(self.y.shape) - itstat_dict = {"Iter": "%d", "Time": "%8.2e"} + itstat_fields = {"Iter": "%d", "Time": "%8.2e"} def itstat_func(obj): return (obj.itnum, obj.timer.elapsed()) @@ -36,7 +36,6 @@ def itstat_func(obj): tau=τ, sigma=σ, maxiter=maxiter, - verbose=False, ) assert len(pdhg_.itstat_object.fieldname) == 4 assert snp.sum(pdhg_.x) == 0.0 @@ -47,8 +46,7 @@ def itstat_func(obj): tau=τ, sigma=σ, maxiter=maxiter, - verbose=False, - itstat=(itstat_dict, itstat_func), + itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False}, ) assert len(pdhg_.itstat_object.fieldname) == 2 @@ -93,7 +91,6 @@ def test_pdhg(self): tau=τ, sigma=σ, maxiter=maxiter, - verbose=False, x0=A.adj(self.y), ) x = pdhg_.solve() @@ -133,7 +130,6 @@ def test_pdhg(self): tau=τ, sigma=σ, maxiter=maxiter, - verbose=False, x0=A.adj(self.y), ) x = pdhg_.solve() diff --git a/scico/test/test_pgm.py b/scico/test/test_pgm.py index b5dbdf226..770b72920 100644 --- a/scico/test/test_pgm.py +++ b/scico/test/test_pgm.py @@ -37,7 +37,7 @@ def test_pgm(self): L0 = 1.05 * linop.power_iteration(A.T @ A)[0] loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() - pgm_ = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, verbose=False, x0=A.adj(self.y)) + pgm_ = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y)) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -48,9 +48,7 @@ def test_accelerated_pgm(self): L0 = 1.05 * linop.power_iteration(A.T @ A)[0] loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() - apgm_ = AcceleratedPGM( - f=loss_, g=g, L0=L0, maxiter=maxiter, verbose=False, x0=A.adj(self.y) - ) + apgm_ = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y)) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -67,7 +65,6 @@ def test_pgm_BB_step_size(self): x0=A.adj(self.y), step_size=BBStepSize(), maxiter=maxiter, - verbose=False, ) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -85,7 +82,6 @@ def test_pgm_adaptive_BB_step_size(self): x0=A.adj(self.y), step_size=AdaptiveBBStepSize(), maxiter=maxiter, - verbose=False, ) x = pgm_.solve() @@ -102,7 +98,6 @@ def test_accelerated_pgm_line_search(self): x0=A.adj(self.y), step_size=LineSearchStepSize(gamma_u=1.03, maxiter=55), maxiter=maxiter, - verbose=False, ) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -120,7 +115,6 @@ def test_accelerated_pgm_robust_line_search(self): x0=A.adj(self.y), step_size=RobustLineSearchStepSize(gamma_d=0.95, gamma_u=1.05, maxiter=80), maxiter=maxiter, - verbose=False, ) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -138,7 +132,6 @@ def test_pgm_BB_step_size_jit(self): x0=A.adj(self.y), step_size=BBStepSize(), maxiter=maxiter, - verbose=False, ) x = pgm_.x try: @@ -161,7 +154,6 @@ def test_accelerated_pgm_adaptive_BB_step_size_jit(self): x0=A.adj(self.y), step_size=AdaptiveBBStepSize(), maxiter=maxiter, - verbose=False, ) x = apgm_.x try: @@ -201,7 +193,6 @@ def test_pgm(self): L0=L0, x0=A.adj(self.y), maxiter=maxiter, - verbose=False, ) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -212,9 +203,7 
@@ def test_accelerated_pgm(self): L0 = 50.0 loss_ = loss.SquaredL2Loss(y=self.y, A=A) g = (self.λ / 2.0) * functional.SquaredL2Norm() - apgm_ = AcceleratedPGM( - f=loss_, g=g, L0=L0, x0=A.adj(self.y), maxiter=maxiter, verbose=False - ) + apgm_ = AcceleratedPGM(f=loss_, g=g, L0=L0, x0=A.adj(self.y), maxiter=maxiter) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -231,7 +220,6 @@ def test_pgm_BB_step_size(self): x0=A.adj(self.y), step_size=BBStepSize(), maxiter=maxiter, - verbose=False, ) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -249,7 +237,6 @@ def test_pgm_adaptive_BB_step_size(self): x0=A.adj(self.y), step_size=AdaptiveBBStepSize(), maxiter=maxiter, - verbose=False, ) x = pgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -267,7 +254,6 @@ def test_accelerated_pgm_line_search(self): x0=A.adj(self.y), step_size=LineSearchStepSize(gamma_u=1.03, maxiter=55), maxiter=maxiter, - verbose=False, ) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3) @@ -285,7 +271,6 @@ def test_accelerated_pgm_robust_line_search(self): x0=A.adj(self.y), step_size=RobustLineSearchStepSize(gamma_d=0.95, gamma_u=1.05, maxiter=80), maxiter=maxiter, - verbose=False, ) x = apgm_.solve() np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)
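
Note (not part of the patch): the following minimal sketch illustrates how the new itstat_options parameter is used in place of the removed verbose/itstat parameters. The toy problem, regularization weight, and import paths are illustrative assumptions rather than code taken from this diff.

# Hypothetical toy example of the new itstat_options interface.
import numpy as np
import jax.numpy as jnp

from scico import functional, linop, loss
from scico.admm import ADMM, LinearSubproblemSolver

y = jnp.asarray(np.random.randn(32, 32).astype(np.float32))  # toy "measurement"
f = loss.SquaredL2Loss(y=y, A=linop.Identity(y.shape))
g = 1e-1 * functional.L1Norm()
C = linop.FiniteDifference(input_shape=y.shape)

solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[1e0],
    x0=y,
    maxiter=20,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 10}),
    # Replaces the old verbose=True flag: print iteration statistics,
    # retaining one line every 5 iterations.
    itstat_options={"display": True, "period": 5},
)
x = solver.solve()
hist = solver.itstat_object.history(transpose=True)

With "period": 5 and the default overwrite=True added to IterationStats, every iteration is printed but only one line in five persists on screen; the itstat_object.end() call added to each solve() method leaves the display in a clean state afterwards.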