Skip to content

Commit

Permalink
Improve iteration statistics mechanism (#130)
Browse files Browse the repository at this point in the history
* Various cleanup, mainly docstrings

* Docstring cleanup

* Typo bug fix

* Use intersphinx refs

* Some improvements

* Remove explicit copy of jax docstrings

* Clean up docstrings

* Fix indentation issues

* Clean up docstrings

* Minor edit

* Clean up docstrings

* Style guide compliance

* Docstring cleanup

* Style guide compliance

* Cleanup

* Fix docstring format problem

* Cleanup and style compliance

* Apply black manually

* Docstring cleanup and style compliance

* Add subsampling of displayed statistics

* Minor edits

* Add optional overwrite and handling for end of iterations

* Minor docstring edits

* Change mechanism for specifying itstat options in optimizer classes

* Add missing itstat end calls

* Clean up output

* Add a test

* Change default display period value
  • Loading branch information
bwohlberg authored Dec 14, 2021
1 parent ddeee4e commit 4e88014
Show file tree
Hide file tree
Showing 31 changed files with 323 additions and 289 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/ct_astra_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
x0=x0,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": num_inner_iter}),
verbose=True,
itstat_options={"display": True, period: 5},
)


Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/ct_astra_weighted_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def postprocess(x):
x0=x0,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": max_inner_iter}),
verbose=True,
itstat_options={"display": True, period: 10},
)
print(f"Solving on {device_info()}\n")
admm_unweighted.solve()
Expand Down Expand Up @@ -148,7 +148,7 @@ def postprocess(x):
maxiter=maxiter,
x0=x0,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": max_inner_iter}),
verbose=True,
itstat_options={"display": True, period: 10},
)
admm_weighted.solve()
x_weighted = postprocess(admm_weighted.x)
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
x0=x0,
maxiter=20,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}),
verbose=True,
itstat_options={"display": True, period: 1},
)


Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
x0=x0,
maxiter=20,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}),
verbose=True,
itstat_options={"display": True},
)


Expand Down
12 changes: 6 additions & 6 deletions examples/scripts/ct_svmbir_tv_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@
x0=x0,
maxiter=50,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 10}),
verbose=True,
itstat_options={"display": True, period: 10},
)
print(f"Solving on {device_info()}\n")
x_admm = solve_admm.solve()
hist_admm = solve_admm.itstat_object.history(transpose=True)
print(metric.psnr(x_gt, x_admm))
print(f"PSNR: {metric.psnr(x_gt, x_admm):.2f} dB\n")


"""
Expand All @@ -123,11 +123,11 @@
nu=2e-1,
x0=x0,
maxiter=50,
verbose=True,
itstat_options={"display": True, period: 10},
)
x_ladmm = solver_ladmm.solve()
hist_ladmm = solver_ladmm.itstat_object.history(transpose=True)
print(metric.psnr(x_gt, x_ladmm))
print(f"PSNR: {metric.psnr(x_gt, x_ladmm):.2f} dB\n")


"""
Expand All @@ -141,11 +141,11 @@
sigma=8e0,
x0=x0,
maxiter=50,
verbose=True,
itstat_options={"display": True, period: 10},
)
x_pdhg = solver_pdhg.solve()
hist_pdhg = solver_pdhg.itstat_object.history(transpose=True)
print(metric.psnr(x_gt, x_pdhg))
print(f"PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB\n")


"""
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/deconv_circ_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
x0=A.adj(y),
maxiter=maxiter,
subproblem_solver=CircularConvolveSolver(),
verbose=True,
itstat_options={"display": True, period: 10},
)


Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/deconv_microscopy_allchn_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,16 @@ def deconvolve_channel(channel):
g2 = functional.NonNegativeIndicator() # non-negativity constraint
if channel == 0:
print("Displaying solver status for channel 0")
verbose = True
display = True
else:
verbose = False
display = False
solver = ADMM(
f=None,
g_list=[g0, g1, g2],
C_list=[C0, C1, C2],
rho_list=[ρ0, ρ1, ρ2],
maxiter=maxiter,
verbose=verbose,
itstat_options={"display": display, period: 10},
x0=y_pad,
subproblem_solver=CircularConvolveSolver(),
)
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/deconv_microscopy_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def block_avg(im, N):
C_list=[C0, C1, C2],
rho_list=[ρ0, ρ1, ρ2],
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
x0=y_pad,
subproblem_solver=CircularConvolveSolver(),
)
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/deconv_ppp_bm3d_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
x0=A.T @ y,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}),
verbose=True,
itstat_options={"display": True},
)


Expand Down
4 changes: 3 additions & 1 deletion examples/scripts/deconv_ppp_bm3d_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@

maxiter = 50 # number of APGM iterations

solver = AcceleratedPGM(f=f, g=g, L0=L0, x0=A.T @ y, maxiter=maxiter, verbose=True)
solver = AcceleratedPGM(
f=f, g=g, L0=L0, x0=A.T @ y, maxiter=maxiter, itstat_options={"display": True, "period": 10}
)

print(f"Solving on {device_info()}\n")
x = solver.solve()
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/deconv_ppp_dncnn_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
x0=A.T @ y,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 30}),
verbose=True,
itstat_options={"display": True},
)


Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/deconv_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
x0=A.adj(y),
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(),
verbose=True,
itstat_options={"display": True, period: 10},
)


Expand Down
1 change: 0 additions & 1 deletion examples/scripts/deconv_tv_admm_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def eval_params(config):
x0=A.adj(y),
maxiter=5,
subproblem_solver=LinearSubproblemSolver(),
verbose=False,
)
# Perform 50 iterations, reporting performance to ray.tune every 5 iterations.
for step in range(10):
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/demosaic_ppp_bm3d_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def demosaic(cfaimg):
x0=imgb,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 100}),
verbose=True,
itstat_options={"display": True},
)


Expand Down
7 changes: 4 additions & 3 deletions examples/scripts/denoise_tv_iso_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@
x0=y,
maxiter=100,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 20}),
verbose=True,
itstat_options={"display": True, period: 10},
)

print(f"Solving on {device_info()}\n")
solver.solve()
x_iso = solver.x

print()

"""
Denoise with anisotropic total variation for comparison.
Expand All @@ -93,11 +93,12 @@
x0=y,
maxiter=100,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 20}),
verbose=True,
itstat_options={"display": True, period: 10},
)

solver.solve()
x_aniso = solver.x
print()


"""
Expand Down
7 changes: 3 additions & 4 deletions examples/scripts/denoise_tv_iso_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
x0=y,
maxiter=1,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 1}),
verbose=False,
)
solver_admm.solve()
# trailing semi-colon suppresses output in notebook
Expand All @@ -88,7 +87,7 @@
x0=y,
maxiter=200,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"maxiter": 2}),
verbose=True,
itstat_options={"display": True, period: 10},
)
print(f"Solving on {device_info()}\n")
solver_admm.solve()
Expand All @@ -106,7 +105,7 @@
nu=1e-1,
x0=y,
maxiter=200,
verbose=True,
itstat_options={"display": True, period: 10},
)
solver_ladmm.solve()
hist_ladmm = solver_ladmm.itstat_object.history(transpose=True)
Expand All @@ -122,7 +121,7 @@
tau=4e-1,
sigma=4e-1,
maxiter=200,
verbose=True,
itstat_options={"display": True, period: 10},
)
solver_pdhg.solve()
hist_pdhg = solver_pdhg.itstat_object.history(transpose=True)
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/denoise_tv_iso_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def prox(self, x: JaxArray, lam: float, **kwargs) -> JaxArray:
L0=16.0 * f_iso.lmbda ** 2,
x0=x0,
maxiter=100,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=RobustLineSearchStepSize(),
)

Expand Down Expand Up @@ -194,7 +194,7 @@ def prox(self, x: JaxArray, lam: float, **kwargs) -> JaxArray:
L0=16.0 * f.lmbda ** 2,
x0=x0,
maxiter=100,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=RobustLineSearchStepSize(),
)

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/sparsecode_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
x0=A.adj(y),
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(),
verbose=True,
itstat_options={"display": True, period: 10},
)


Expand Down
4 changes: 3 additions & 1 deletion examples/scripts/sparsecode_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@
A = linop.MatrixOperator(D)
f = loss.SquaredL2Loss(y=y, A=A)
g = λ * functional.L1Norm()
solver = AcceleratedPGM(f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, verbose=True)
solver = AcceleratedPGM(
f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, itstat_options={"display": True, "period": 10}
)


"""
Expand Down
10 changes: 5 additions & 5 deletions examples/scripts/sparsecode_poisson_blkarr_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
)
str_ss = type(solver.step_size).__name__

Expand Down Expand Up @@ -216,7 +216,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=BBStepSize(),
)
str_ss = type(solver.step_size).__name__
Expand Down Expand Up @@ -244,7 +244,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=AdaptiveBBStepSize(kappa=0.75),
)
str_ss = type(solver.step_size).__name__
Expand Down Expand Up @@ -272,7 +272,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=LineSearchStepSize(),
)
str_ss = type(solver.step_size).__name__
Expand Down Expand Up @@ -300,7 +300,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Aop):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=RobustLineSearchStepSize(),
)
str_ss = type(solver.step_size).__name__
Expand Down
10 changes: 5 additions & 5 deletions examples/scripts/sparsecode_poisson_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
)
str_ss = type(solver.step_size).__name__

Expand Down Expand Up @@ -187,7 +187,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=BBStepSize(),
)
str_ss = type(solver.step_size).__name__
Expand Down Expand Up @@ -215,7 +215,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=AdaptiveBBStepSize(kappa=0.75),
)
str_ss = type(solver.step_size).__name__
Expand Down Expand Up @@ -243,7 +243,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=LineSearchStepSize(gamma_u=1.01),
)
str_ss = type(solver.step_size).__name__
Expand Down Expand Up @@ -271,7 +271,7 @@ def plot_results(hist, str_ss, L0, xsol, xgt, Amat):
L0=L0,
x0=x0,
maxiter=maxiter,
verbose=True,
itstat_options={"display": True, period: 10},
step_size=RobustLineSearchStepSize(),
)
str_ss = type(solver.step_size).__name__
Expand Down
Loading

0 comments on commit 4e88014

Please sign in to comment.