diff --git a/examples/scripts/ct_astra_tv_admm.py b/examples/scripts/ct_astra_tv_admm.py index 82d61e3c5..8f278ec3a 100644 --- a/examples/scripts/ct_astra_tv_admm.py +++ b/examples/scripts/ct_astra_tv_admm.py @@ -57,7 +57,8 @@ λ = 2e-0 # L1 norm regularization parameter ρ = 5e-0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations -num_inner_iter = 20 # number of CG iterations per ADMM iteration +cg_tol = 1e-4 # CG relative tolerance +cg_maxiter = 25 # maximum CG iterations per ADMM iteration g = λ * functional.L1Norm() # regularization functionals gi C = linop.FiniteDifference(input_shape=x_gt.shape) # analysis operators Ci @@ -73,7 +74,7 @@ rho_list=[ρ], x0=x0, maxiter=maxiter, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": num_inner_iter}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 5}, ) diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py index 187ebfcdb..180b487d2 100644 --- a/examples/scripts/ct_astra_weighted_tv_admm.py +++ b/examples/scripts/ct_astra_weighted_tv_admm.py @@ -106,7 +106,8 @@ def postprocess(x): lambda_unweighted = 2.56e2 # regularization strength maxiter = 50 # number of ADMM iterations -max_inner_iter = 10 # number of CG iterations per ADMM iteration +cg_tol = 1e-5 # CG relative tolerance +cg_maxiter = 10 # maximum CG iterations per ADMM iteration f = loss.SquaredL2Loss(y=y, A=A) @@ -117,7 +118,7 @@ def postprocess(x): rho_list=[ρ], x0=x0, maxiter=maxiter, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": max_inner_iter}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") @@ -147,7 +148,7 @@ def postprocess(x): rho_list=[ρ], maxiter=maxiter, x0=x0, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": max_inner_iter}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 10}, ) admm_weighted.solve() diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py index dc6bfef74..e56021161 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py @@ -103,7 +103,7 @@ rho_list=[ρ, ρ], x0=x0, maxiter=20, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 100}), itstat_options={"display": True, "period": 1}, ) diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py index 03103bc88..b19253d7c 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py @@ -105,7 +105,7 @@ rho_list=[ρ, ρ, ρ], x0=x0, maxiter=20, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), itstat_options={"display": True}, ) diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py index f213b3419..353405371 100644 --- a/examples/scripts/ct_svmbir_tv_multi.py +++ b/examples/scripts/ct_svmbir_tv_multi.py @@ -102,7 +102,7 @@ rho_list=[2e1], x0=x0, maxiter=50, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 10}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 10}), itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") diff --git a/examples/scripts/deconv_ppp_bm3d_admm.py b/examples/scripts/deconv_ppp_bm3d_admm.py index 4103cebec..016884543 100644 --- a/examples/scripts/deconv_ppp_bm3d_admm.py +++ b/examples/scripts/deconv_ppp_bm3d_admm.py @@ -68,7 +68,7 @@ rho_list=[ρ], x0=A.T @ y, maxiter=maxiter, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), itstat_options={"display": True}, ) diff --git a/examples/scripts/deconv_ppp_dncnn_admm.py b/examples/scripts/deconv_ppp_dncnn_admm.py index daaae937f..f98d1db18 100644 --- a/examples/scripts/deconv_ppp_dncnn_admm.py +++ b/examples/scripts/deconv_ppp_dncnn_admm.py @@ -73,7 +73,7 @@ rho_list=[ρ], x0=A.T @ y, maxiter=maxiter, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 30}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 30}), itstat_options={"display": True}, ) diff --git a/examples/scripts/demosaic_ppp_bm3d_admm.py b/examples/scripts/demosaic_ppp_bm3d_admm.py index c670408b2..82b52a026 100644 --- a/examples/scripts/demosaic_ppp_bm3d_admm.py +++ b/examples/scripts/demosaic_ppp_bm3d_admm.py @@ -114,7 +114,7 @@ def demosaic(cfaimg): rho_list=[ρ], x0=imgb, maxiter=maxiter, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), itstat_options={"display": True}, ) diff --git a/examples/scripts/denoise_tv_iso_admm.py b/examples/scripts/denoise_tv_iso_admm.py index cc5a45563..254446f78 100644 --- a/examples/scripts/denoise_tv_iso_admm.py +++ b/examples/scripts/denoise_tv_iso_admm.py @@ -69,7 +69,7 @@ rho_list=[1e1], x0=y, maxiter=100, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 20}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}), itstat_options={"display": True, "period": 10}, ) @@ -92,7 +92,7 @@ rho_list=[1e1], x0=y, maxiter=100, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 20}), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}), itstat_options={"display": True, "period": 10}, ) diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index d8cea0f0f..0eac37b6a 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -158,25 +158,32 @@ class LinearSubproblemSolver(SubproblemSolver): :math:`\mb{x}` update step. """ - def __init__(self, cg_kwargs: dict = {"maxiter": 100}, cg_function: str = "scico"): + def __init__(self, cg_kwargs: Optional[dict] = None, cg_function: str = "scico"): """Initialize a :class:`LinearSubproblemSolver` object. Args: cg_kwargs: Dictionary of arguments for CG solver. See - :func:`scico.solver.cg` or - :func:`jax.scipy.sparse.linalg.cg`, documentation, + documentation for :func:`scico.solver.cg` or + :func:`jax.scipy.sparse.linalg.cg`, including how to specify a preconditioner. + Default values are the same as those of + :func:`scico.solver.cg`, except for + ``"tol": 1e-4`` and ``"maxiter": 100``. cg_function: String indicating which CG implementation to use. One of "jax" or "scico"; default "scico". If "scico", uses :func:`scico.solver.cg`. If "jax", uses - :func:`jax.scipy.sparse.linalg.cg`. The "jax" option is + :func:`jax.scipy.sparse.linalg.cg`. The "jax" option is slower on small-scale problems or problems involving external functions, but can be differentiated through. The "scico" option is faster on small-scale problems, but slower on large-scale problems where the forward operator is written entirely in jax. """ - self.cg_kwargs = cg_kwargs + + default_cg_kwargs = {"tol": 1e-4, "maxiter": 100} + if cg_kwargs: + default_cg_kwargs.update(cg_kwargs) + self.cg_kwargs = default_cg_kwargs if cg_function == "scico": self.cg = scico_cg elif cg_function == "jax": diff --git a/scico/test/optimize/test_admm.py b/scico/test/optimize/test_admm.py index ab3327868..44305daae 100644 --- a/scico/test/optimize/test_admm.py +++ b/scico/test/optimize/test_admm.py @@ -120,10 +120,10 @@ def test_admm_quadratic_scico(self): maxiter=maxiter, itstat_options={"display": False}, x0=A.adj(self.y), - subproblem_solver=LinearSubproblemSolver(cg_function="scico"), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="scico"), ) x = admm_.solve() - assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5 + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 def test_admm_quadratic_jax(self): maxiter = 50 @@ -141,10 +141,10 @@ def test_admm_quadratic_jax(self): maxiter=maxiter, itstat_options={"display": False}, x0=A.adj(self.y), - subproblem_solver=LinearSubproblemSolver(cg_function="jax"), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="jax"), ) x = admm_.solve() - assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5 + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 def test_admm_quadratic_relax(self): maxiter = 50 @@ -163,10 +163,10 @@ def test_admm_quadratic_relax(self): maxiter=maxiter, itstat_options={"display": False}, x0=A.adj(self.y), - subproblem_solver=LinearSubproblemSolver(cg_function="jax"), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="jax"), ) x = admm_.solve() - assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5 + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 class TestRealWeighted: @@ -211,10 +211,10 @@ def test_admm_quadratic(self): maxiter=maxiter, itstat_options={"display": False}, x0=A.adj(self.y), - subproblem_solver=LinearSubproblemSolver(cg_function="scico"), + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="scico"), ) x = admm_.solve() - assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5 + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 class TestComplex: @@ -258,7 +258,7 @@ def test_admm_generic(self): ), ) x = admm_.solve() - assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 2e-4 + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-3 def test_admm_quadratic(self): maxiter = 50 @@ -276,10 +276,12 @@ def test_admm_quadratic(self): maxiter=maxiter, itstat_options={"display": False}, x0=A.adj(self.y), - subproblem_solver=LinearSubproblemSolver(), + subproblem_solver=LinearSubproblemSolver( + cg_kwargs={"tol": 1e-4}, + ), ) x = admm_.solve() - assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5 + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 class TestCircularConvolveSolve: