diff --git a/data b/data index d7b0478e3..3a92b5fa4 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit d7b0478e3769f0cb8ec72b7cf936eb2a997ee8f2 +Subproject commit 3a92b5fa46fad6c04ad452ed986ac861013a4201 diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py index 2131ca6f7..4617fef8f 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py @@ -15,9 +15,16 @@ :cite:`dabov-2008-image` as a denoiser and SVMBIR :cite:`svmbir-2020` for tomographic projection. -This version uses the data fidelity term as one of the ADMM g functionals, -and thus the optimization with respect to the data fidelity is able to -exploit the internal prox of the SVMBIRSquaredL2Loss functional. +This version uses the data fidelity term as one of the ADMM $g$ +functionals so that the optimization with respect to the data fidelity is +able to exploit the internal prox of the `SVMBIRExtendedLoss` and +`SVMBIRSquaredL2Loss` functionals. + +We solve the problem in two different ways: +1. Using the `SVMBIRSquaredL2Loss` together with the BM3D pseudo-functional + and a non-negative indicator function, and +2. Using the `SVMBIRExtendedLoss`, which includes a non-negativity + constraint, together with the BM3D pseudo-functional. """ import numpy as np @@ -26,13 +33,18 @@ import matplotlib.pyplot as plt import svmbir +from matplotlib.ticker import MaxNLocator from xdesign import Foam, discrete_phantom import scico.numpy as snp from scico import metric, plot from scico.functional import BM3D, NonNegativeIndicator from scico.linop import Diagonal, Identity -from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRSquaredL2Loss +from scico.linop.radon_svmbir import ( + ParallelBeamProjector, + SVMBIRExtendedLoss, + SVMBIRSquaredL2Loss, +) from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -85,22 +97,30 @@ """ -Set up an ADMM solver. +Push arrays to device. """ y, x0, weights = jax.device_put([y, x_mrf, weights]) + +""" +Set problem parameters and BM3D pseudo-functional. +""" ρ = 10 # ADMM penalty parameter σ = density * 0.26 # denoiser sigma +g0 = σ * ρ * BM3D() -f = SVMBIRSquaredL2Loss( + +""" +Set up problem using `SVMBIRSquaredL2Loss` and `NonNegativeIndicator`. +""" +f_l2loss = SVMBIRSquaredL2Loss( y=y, A=A, W=Diagonal(weights), scale=0.5, prox_kwargs={"maxiter": 5, "ctol": 0.0} ) -g0 = σ * ρ * BM3D() g1 = NonNegativeIndicator() -solver = ADMM( +solver_l2loss = ADMM( f=None, - g_list=[f, g0, g1], + g_list=[f_l2loss, g0, g1], C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape), Identity(x_mrf.shape)], rho_list=[ρ, ρ, ρ], x0=x0, @@ -111,33 +131,73 @@ """ -Run the solver. +Run the ADMM solver. """ print(f"Solving on {device_info()}\n") -x_bm3d = solver.solve() -hist = solver.itstat_object.history(transpose=True) +x_l2loss = solver_l2loss.solve() +hist_l2loss = solver_l2loss.itstat_object.history(transpose=True) + + +""" +Set up problem using `SVMBIRExtendedLoss`, without need for `NonNegativeIndicator`. +""" +f_extloss = SVMBIRExtendedLoss( + y=y, + A=A, + W=Diagonal(weights), + scale=0.5, + positivity=True, + prox_kwargs={"maxiter": 5, "ctol": 0.0}, +) + +solver_extloss = ADMM( + f=None, + g_list=[f_extloss, g0], + C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape)], + rho_list=[ρ, ρ], + x0=x0, + maxiter=20, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}), + itstat_options={"display": True}, +) + + +""" +Run the ADMM solver. +""" +print() +x_extloss = solver_extloss.solve() +hist_extloss = solver_extloss.itstat_object.history(transpose=True) """ -Show the recovered image. +Show the recovered images. """ norm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density) -fig, ax = plt.subplots(1, 3, figsize=[15, 5]) -plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0], norm=norm) +fig, ax = plt.subplots(2, 2, figsize=(15, 15)) +plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0, 0], norm=norm) plot.imview( img=x_mrf, title=f"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)", cbar=True, fig=fig, - ax=ax[1], + ax=ax[0, 1], norm=norm, ) plot.imview( - img=x_bm3d, - title=f"BM3D (PSNR: {metric.psnr(x_gt, x_bm3d):.2f} dB)", + img=x_l2loss, + title=f"SquaredL2Loss + non-negativity (PSNR: {metric.psnr(x_gt, x_l2loss):.2f} dB)", cbar=True, fig=fig, - ax=ax[2], + ax=ax[1, 0], + norm=norm, +) +plot.imview( + img=x_extloss, + title=f"ExtendedLoss (PSNR: {metric.psnr(x_gt, x_extloss):.2f} dB)", + cbar=True, + fig=fig, + ax=ax[1, 1], norm=norm, ) fig.show() @@ -146,13 +206,30 @@ """ Plot convergence statistics. """ +fig, ax = plt.subplots(1, 2, figsize=(15, 5)) +plot.plot( + snp.vstack((hist_l2loss.Prml_Rsdl, hist_l2loss.Dual_Rsdl)).T, + ptyp="semilogy", + title="Residuals (SquaredL2Loss + non-negativity)", + xlbl="Iteration", + lgnd=("Primal", "Dual"), + fig=fig, + ax=ax[0], +) +ax[0].set_ylim([5e-3, 1e0]) +ax[0].xaxis.set_major_locator(MaxNLocator(integer=True)) plot.plot( - snp.vstack((hist.Prml_Rsdl, hist.Dual_Rsdl)).T, + snp.vstack((hist_extloss.Prml_Rsdl, hist_extloss.Dual_Rsdl)).T, ptyp="semilogy", - title="Residuals", + title="Residuals (ExtendedLoss)", xlbl="Iteration", lgnd=("Primal", "Dual"), + fig=fig, + ax=ax[1], ) +ax[1].set_ylim([5e-3, 1e0]) +ax[1].xaxis.set_major_locator(MaxNLocator(integer=True)) +fig.show() input("\nWaiting for input to close figures and exit") diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 0f0eed64c..551627134 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -151,8 +151,8 @@ def _bproj( svmbir.backproject( np.array(y), np.array(angles), - num_rows, - num_cols, + num_rows=num_rows, + num_cols=num_cols, verbose=0, center_offset=center_offset, roi_radius=roi_radius,