Skip to content

Commit

Permalink
Modify cg_kwargs mechanism in admm.LinearSubproblemSolver (#140)
Browse files Browse the repository at this point in the history
* Change cg_kwargs mechanism so that default values are preserved unless explicitly changed

* Change default tolerance

* Change cg_kwargs in example scripts

* Resolve test failures

* Include default values in docstring
  • Loading branch information
bwohlberg authored Dec 18, 2021
1 parent fe2a50e commit fc2852a
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 29 deletions.
5 changes: 3 additions & 2 deletions examples/scripts/ct_astra_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
λ = 2e-0 # L1 norm regularization parameter
ρ = 5e-0 # ADMM penalty parameter
maxiter = 25 # number of ADMM iterations
num_inner_iter = 20 # number of CG iterations per ADMM iteration
cg_tol = 1e-4 # CG relative tolerance
cg_maxiter = 25 # maximum CG iterations per ADMM iteration

g = λ * functional.L1Norm() # regularization functionals gi
C = linop.FiniteDifference(input_shape=x_gt.shape) # analysis operators Ci
Expand All @@ -73,7 +74,7 @@
rho_list=[ρ],
x0=x0,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": num_inner_iter}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
itstat_options={"display": True, "period": 5},
)

Expand Down
7 changes: 4 additions & 3 deletions examples/scripts/ct_astra_weighted_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def postprocess(x):
lambda_unweighted = 2.56e2 # regularization strength

maxiter = 50 # number of ADMM iterations
max_inner_iter = 10 # number of CG iterations per ADMM iteration
cg_tol = 1e-5 # CG relative tolerance
cg_maxiter = 10 # maximum CG iterations per ADMM iteration

f = loss.SquaredL2Loss(y=y, A=A)

Expand All @@ -117,7 +118,7 @@ def postprocess(x):
rho_list=[ρ],
x0=x0,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": max_inner_iter}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
Expand Down Expand Up @@ -147,7 +148,7 @@ def postprocess(x):
rho_list=[ρ],
maxiter=maxiter,
x0=x0,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": max_inner_iter}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
itstat_options={"display": True, "period": 10},
)
admm_weighted.solve()
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
rho_list=[ρ, ρ],
x0=x0,
maxiter=20,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 100}),
itstat_options={"display": True, "period": 1},
)

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
rho_list=[ρ, ρ, ρ],
x0=x0,
maxiter=20,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
itstat_options={"display": True},
)

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_svmbir_tv_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
rho_list=[2e1],
x0=x0,
maxiter=50,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 10}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4, "maxiter": 10}),
itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/deconv_ppp_bm3d_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
rho_list=[ρ],
x0=A.T @ y,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
itstat_options={"display": True},
)

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/deconv_ppp_dncnn_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
rho_list=[ρ],
x0=A.T @ y,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 30}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 30}),
itstat_options={"display": True},
)

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/demosaic_ppp_bm3d_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def demosaic(cfaimg):
rho_list=[ρ],
x0=imgb,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
itstat_options={"display": True},
)

Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/denoise_tv_iso_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
rho_list=[1e1],
x0=y,
maxiter=100,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 20}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
itstat_options={"display": True, "period": 10},
)

Expand All @@ -92,7 +92,7 @@
rho_list=[1e1],
x0=y,
maxiter=100,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 20}),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
itstat_options={"display": True, "period": 10},
)

Expand Down
17 changes: 12 additions & 5 deletions scico/optimize/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,25 +158,32 @@ class LinearSubproblemSolver(SubproblemSolver):
:math:`\mb{x}` update step.
"""

def __init__(self, cg_kwargs: dict = {"maxiter": 100}, cg_function: str = "scico"):
def __init__(self, cg_kwargs: Optional[dict] = None, cg_function: str = "scico"):
"""Initialize a :class:`LinearSubproblemSolver` object.
Args:
cg_kwargs: Dictionary of arguments for CG solver. See
:func:`scico.solver.cg` or
:func:`jax.scipy.sparse.linalg.cg`, documentation,
documentation for :func:`scico.solver.cg` or
:func:`jax.scipy.sparse.linalg.cg`,
including how to specify a preconditioner.
Default values are the same as those of
:func:`scico.solver.cg`, except for
``"tol": 1e-4`` and ``"maxiter": 100``.
cg_function: String indicating which CG implementation to
use. One of "jax" or "scico"; default "scico". If
"scico", uses :func:`scico.solver.cg`. If "jax", uses
:func:`jax.scipy.sparse.linalg.cg`. The "jax" option is
:func:`jax.scipy.sparse.linalg.cg`. The "jax" option is
slower on small-scale problems or problems involving
external functions, but can be differentiated through.
The "scico" option is faster on small-scale problems, but
slower on large-scale problems where the forward
operator is written entirely in jax.
"""
self.cg_kwargs = cg_kwargs

default_cg_kwargs = {"tol": 1e-4, "maxiter": 100}
if cg_kwargs:
default_cg_kwargs.update(cg_kwargs)
self.cg_kwargs = default_cg_kwargs
if cg_function == "scico":
self.cg = scico_cg
elif cg_function == "jax":
Expand Down
24 changes: 13 additions & 11 deletions scico/test/optimize/test_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def test_admm_quadratic_scico(self):
maxiter=maxiter,
itstat_options={"display": False},
x0=A.adj(self.y),
subproblem_solver=LinearSubproblemSolver(cg_function="scico"),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="scico"),
)
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

def test_admm_quadratic_jax(self):
maxiter = 50
Expand All @@ -141,10 +141,10 @@ def test_admm_quadratic_jax(self):
maxiter=maxiter,
itstat_options={"display": False},
x0=A.adj(self.y),
subproblem_solver=LinearSubproblemSolver(cg_function="jax"),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="jax"),
)
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4

def test_admm_quadratic_relax(self):
maxiter = 50
Expand All @@ -163,10 +163,10 @@ def test_admm_quadratic_relax(self):
maxiter=maxiter,
itstat_options={"display": False},
x0=A.adj(self.y),
subproblem_solver=LinearSubproblemSolver(cg_function="jax"),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="jax"),
)
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4


class TestRealWeighted:
Expand Down Expand Up @@ -211,10 +211,10 @@ def test_admm_quadratic(self):
maxiter=maxiter,
itstat_options={"display": False},
x0=A.adj(self.y),
subproblem_solver=LinearSubproblemSolver(cg_function="scico"),
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-4}, cg_function="scico"),
)
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4


class TestComplex:
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_admm_generic(self):
),
)
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 2e-4
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-3

def test_admm_quadratic(self):
maxiter = 50
Expand All @@ -276,10 +276,12 @@ def test_admm_quadratic(self):
maxiter=maxiter,
itstat_options={"display": False},
x0=A.adj(self.y),
subproblem_solver=LinearSubproblemSolver(),
subproblem_solver=LinearSubproblemSolver(
cg_kwargs={"tol": 1e-4},
),
)
x = admm_.solve()
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4


class TestCircularConvolveSolve:
Expand Down

0 comments on commit fc2852a

Please sign in to comment.