Skip to content

Commit

Permalink
Fix bug and add example (#246)
Browse files Browse the repository at this point in the history
Fixes bug in scico.svmbir that caused errors with the new release of svmbir
Fixes issue #198 by adding an example that uses the svmbir extended losss

Co-authored-by: Brendt Wohlberg <[email protected]>
  • Loading branch information
smajee and bwohlberg authored Mar 8, 2022
1 parent 6a481de commit 13be7a3
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 24 deletions.
2 changes: 1 addition & 1 deletion data
119 changes: 98 additions & 21 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@
:cite:`dabov-2008-image` as a denoiser and SVMBIR :cite:`svmbir-2020` for
tomographic projection.
This version uses the data fidelity term as one of the ADMM g functionals,
and thus the optimization with respect to the data fidelity is able to
exploit the internal prox of the SVMBIRSquaredL2Loss functional.
This version uses the data fidelity term as one of the ADMM $g$
functionals so that the optimization with respect to the data fidelity is
able to exploit the internal prox of the `SVMBIRExtendedLoss` and
`SVMBIRSquaredL2Loss` functionals.
We solve the problem in two different ways:
1. Using the `SVMBIRSquaredL2Loss` together with the BM3D pseudo-functional
and a non-negative indicator function, and
2. Using the `SVMBIRExtendedLoss`, which includes a non-negativity
constraint, together with the BM3D pseudo-functional.
"""

import numpy as np
Expand All @@ -26,13 +33,18 @@

import matplotlib.pyplot as plt
import svmbir
from matplotlib.ticker import MaxNLocator
from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import metric, plot
from scico.functional import BM3D, NonNegativeIndicator
from scico.linop import Diagonal, Identity
from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRSquaredL2Loss
from scico.linop.radon_svmbir import (
ParallelBeamProjector,
SVMBIRExtendedLoss,
SVMBIRSquaredL2Loss,
)
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand Down Expand Up @@ -85,22 +97,30 @@


"""
Set up an ADMM solver.
Push arrays to device.
"""
y, x0, weights = jax.device_put([y, x_mrf, weights])


"""
Set problem parameters and BM3D pseudo-functional.
"""
ρ = 10 # ADMM penalty parameter
σ = density * 0.26 # denoiser sigma
g0 = σ * ρ * BM3D()

f = SVMBIRSquaredL2Loss(

"""
Set up problem using `SVMBIRSquaredL2Loss` and `NonNegativeIndicator`.
"""
f_l2loss = SVMBIRSquaredL2Loss(
y=y, A=A, W=Diagonal(weights), scale=0.5, prox_kwargs={"maxiter": 5, "ctol": 0.0}
)
g0 = σ * ρ * BM3D()
g1 = NonNegativeIndicator()

solver = ADMM(
solver_l2loss = ADMM(
f=None,
g_list=[f, g0, g1],
g_list=[f_l2loss, g0, g1],
C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape), Identity(x_mrf.shape)],
rho_list=[ρ, ρ, ρ],
x0=x0,
Expand All @@ -111,33 +131,73 @@


"""
Run the solver.
Run the ADMM solver.
"""
print(f"Solving on {device_info()}\n")
x_bm3d = solver.solve()
hist = solver.itstat_object.history(transpose=True)
x_l2loss = solver_l2loss.solve()
hist_l2loss = solver_l2loss.itstat_object.history(transpose=True)


"""
Set up problem using `SVMBIRExtendedLoss`, without need for `NonNegativeIndicator`.
"""
f_extloss = SVMBIRExtendedLoss(
y=y,
A=A,
W=Diagonal(weights),
scale=0.5,
positivity=True,
prox_kwargs={"maxiter": 5, "ctol": 0.0},
)

solver_extloss = ADMM(
f=None,
g_list=[f_extloss, g0],
C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape)],
rho_list=[ρ, ρ],
x0=x0,
maxiter=20,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
itstat_options={"display": True},
)


"""
Run the ADMM solver.
"""
print()
x_extloss = solver_extloss.solve()
hist_extloss = solver_extloss.itstat_object.history(transpose=True)


"""
Show the recovered image.
Show the recovered images.
"""
norm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density)
fig, ax = plt.subplots(1, 3, figsize=[15, 5])
plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0], norm=norm)
fig, ax = plt.subplots(2, 2, figsize=(15, 15))
plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0, 0], norm=norm)
plot.imview(
img=x_mrf,
title=f"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)",
cbar=True,
fig=fig,
ax=ax[1],
ax=ax[0, 1],
norm=norm,
)
plot.imview(
img=x_bm3d,
title=f"BM3D (PSNR: {metric.psnr(x_gt, x_bm3d):.2f} dB)",
img=x_l2loss,
title=f"SquaredL2Loss + non-negativity (PSNR: {metric.psnr(x_gt, x_l2loss):.2f} dB)",
cbar=True,
fig=fig,
ax=ax[2],
ax=ax[1, 0],
norm=norm,
)
plot.imview(
img=x_extloss,
title=f"ExtendedLoss (PSNR: {metric.psnr(x_gt, x_extloss):.2f} dB)",
cbar=True,
fig=fig,
ax=ax[1, 1],
norm=norm,
)
fig.show()
Expand All @@ -146,13 +206,30 @@
"""
Plot convergence statistics.
"""
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
plot.plot(
snp.vstack((hist_l2loss.Prml_Rsdl, hist_l2loss.Dual_Rsdl)).T,
ptyp="semilogy",
title="Residuals (SquaredL2Loss + non-negativity)",
xlbl="Iteration",
lgnd=("Primal", "Dual"),
fig=fig,
ax=ax[0],
)
ax[0].set_ylim([5e-3, 1e0])
ax[0].xaxis.set_major_locator(MaxNLocator(integer=True))
plot.plot(
snp.vstack((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
snp.vstack((hist_extloss.Prml_Rsdl, hist_extloss.Dual_Rsdl)).T,
ptyp="semilogy",
title="Residuals",
title="Residuals (ExtendedLoss)",
xlbl="Iteration",
lgnd=("Primal", "Dual"),
fig=fig,
ax=ax[1],
)
ax[1].set_ylim([5e-3, 1e0])
ax[1].xaxis.set_major_locator(MaxNLocator(integer=True))
fig.show()


input("\nWaiting for input to close figures and exit")
4 changes: 2 additions & 2 deletions scico/linop/radon_svmbir.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def _bproj(
svmbir.backproject(
np.array(y),
np.array(angles),
num_rows,
num_cols,
num_rows=num_rows,
num_cols=num_cols,
verbose=0,
center_offset=center_offset,
roi_radius=roi_radius,
Expand Down

0 comments on commit 13be7a3

Please sign in to comment.