Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PDHG solver with support for non-linear operators #322

Merged
merged 20 commits into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ Version 0.0.3 (unreleased)
• Change filenames of some example scripts (and corresponding notebooks).
• Change required packages and version numbers.
• Add support for Python 3.7.
• Add ``DiagonalStack`` linear operator.
• New ``DiagonalStack`` linear operator.
• Add support for non-linear operators to ``optimize.PDHG`` optimizer class.
• Various bug fixes.


Expand Down
3 changes: 3 additions & 0 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Miscellaneous
:maxdepth: 1

examples/demosaic_ppp_bm3d_admm
examples/denoise_cplx_tv_pdhg
examples/denoise_l1tv_admm
examples/denoise_tv_admm
examples/denoise_tv_pgm
Expand Down Expand Up @@ -114,6 +115,7 @@ Total Variation
examples/deconv_microscopy_allchn_tv_admm
examples/deconv_tv_admm
examples/deconv_tv_admm_tune
examples/denoise_cplx_tv_pdhg
examples/denoise_l1tv_admm
examples/denoise_tv_admm
examples/denoise_tv_pgm
Expand Down Expand Up @@ -186,6 +188,7 @@ PDHG
:maxdepth: 1

examples/ct_svmbir_tv_multi
examples/denoise_cplx_tv_pdhg
examples/denoise_tv_multi


Expand Down
12 changes: 12 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,18 @@ @Misc {svmbir-2020
year = 2020
}

@Article {valkonen-2014-primal,
title = {A primal--dual hybrid gradient method for nonlinear
operators with applications to {MRI}},
author = {Valkonen, Tuomo},
journal = {Inverse Problems},
volume = 30,
number = 5,
pages = {055012},
year = 2014,
doi = {10.1088/0266-5611/30/5/055012}
}

@InProceedings {venkatakrishnan-2013-plugandplay2,
author = {Singanallur V. Venkatakrishnan and Charles A. Bouman
and Brendt Wohlberg},
Expand Down
6 changes: 6 additions & 0 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ Miscellaneous

`demosaic_ppp_bm3d_admm.py <demosaic_ppp_bm3d_admm.py>`_
Image Demosaicing (ADMM Plug-and-Play Priors w/ BM3D)
`denoise_cplx_tv_pdhg.py <denoise_cplx_tv_pdhg.py>`_
Complex Total Variation Denoising (ADMM)
`denoise_l1tv_admm.py <denoise_l1tv_admm.py>`_
ℓ1 Total Variation (ADMM)
`denoise_tv_admm.py <denoise_tv_admm.py>`_
Expand Down Expand Up @@ -128,6 +130,8 @@ Total Variation
Image Deconvolution (ADMM w/ Total Variation)
`deconv_tv_admm_tune.py <deconv_tv_admm_tune.py>`_
Image Deconvolution Parameter Tuning
`denoise_cplx_tv_pdhg.py <denoise_cplx_tv_pdhg.py>`_
Complex Total Variation Denoising (ADMM)
`denoise_l1tv_admm.py <denoise_l1tv_admm.py>`_
ℓ1 Total Variation (ADMM)
`denoise_tv_admm.py <denoise_tv_admm.py>`_
Expand Down Expand Up @@ -218,6 +222,8 @@ PDHG

`ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_
CT Reconstruction with TV Regularization
`denoise_cplx_tv_pdhg.py <denoise_cplx_tv_pdhg.py>`_
Complex Total Variation Denoising (ADMM)
`denoise_tv_multi.py <denoise_tv_multi.py>`_
Comparison of Optimization Algorithms for Total Variation Denoising

Expand Down
220 changes: 220 additions & 0 deletions examples/scripts/denoise_cplx_tv_pdhg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
Complex Total Variation Denoising (ADMM)
========================================

This example demonstrates solution of a problem of the form

$$\argmin_{\mathbf{x}} \; f(\mathbf{x}) + g(C(\mathbf{x})) \;,$$

where $C$ is a nonlinear operator, via non-linear PDHG
:cite:`valkonen-2014-primal`. The example problem represents total
variation (TV) denoising applied to a complex image with piece-wise
smooth magnitude and non-smooth phase. The appropriate TV denoising
formulation for this problem is

$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x}
\|_2^2 + \lambda \| C(\mathbf{x}) \|_{2,1} \;,$$

where $\mathbf{y}$ is the measurement, $\|\cdot\|_{2,1}$ is the
$\ell_{2,1}$ mixed norm, and $C$ is a non-linear operator that applies a
linear difference operator to the magnitude of a complex array. The
standard TV solution, which is also computed for comparison purposes,
gives very poor results since the difference is applied independently to
real and imaginary components of the complex image.
"""

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable
from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric, operator, plot
from scico.examples import phase_diff
from scico.optimize import PDHG
from scico.util import device_info

"""
Create a ground truth image.
"""
N = 256 # image size
phantom = SiemensStar(16)
x_mag = snp.pad(discrete_phantom(phantom, N - 16), 8) + 1.0
x_mag /= x_mag.max()
# Create reference image with structured magnitude and random phase
x_gt = x_mag * snp.exp(-1j * scico.random.randn(x_mag.shape, seed=0)[0])


"""
Add noise to create a noisy test image.
"""
σ = 0.25 # noise standard deviation
noise, key = scico.random.randn(x_gt.shape, seed=1, dtype=snp.complex64)
y = x_gt + σ * noise


"""
Denoise with standard total variation.
"""
λ_tv = 6e-2
Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
f = loss.SquaredL2Loss(y=y)
g = λ_tv * functional.L21Norm()
# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, input_dtype=snp.complex64, append=0)
solver_tv = PDHG(
f=f,
g=g,
C=C,
tau=4e-1,
sigma=4e-1,
maxiter=200,
itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
x_tv = solver_tv.solve()
hist_tv = solver_tv.itstat_object.history(transpose=True)


"""
Denoise with non-linear total variation.
"""
λ_nltv = 2e-1
g = λ_nltv * functional.L21Norm()
# Redefine C for real input (now applied to magnitude of a complex array)
C = linop.FiniteDifference(input_shape=x_gt.shape, input_dtype=snp.float32, append=0)
# Operator computing differences of absolute values
D = C @ operator.Abs(input_shape=x_gt.shape, input_dtype=snp.complex64)
solver_nltv = PDHG(
f=f,
g=g,
C=D,
tau=4e-1,
sigma=4e-1,
maxiter=200,
itstat_options={"display": True, "period": 10},
)
x_nltv = solver_nltv.solve()
hist_nltv = solver_nltv.itstat_object.history(transpose=True)


"""
Plot results.
"""
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6))
plot.plot(
snp.vstack((hist_tv.Objective, hist_nltv.Objective)).T,
ptyp="semilogy",
title="Objective function",
xlbl="Iteration",
lgnd=("PDHG", "NL-PDHG"),
fig=fig,
ax=ax[0],
)
plot.plot(
snp.vstack((hist_tv.Prml_Rsdl, hist_nltv.Prml_Rsdl)).T,
ptyp="semilogy",
title="Primal residual",
xlbl="Iteration",
lgnd=("PDHG", "NL-PDHG"),
fig=fig,
ax=ax[1],
)
plot.plot(
snp.vstack((hist_tv.Dual_Rsdl, hist_nltv.Dual_Rsdl)).T,
ptyp="semilogy",
title="Dual residual",
xlbl="Iteration",
lgnd=("PDHG", "NL-PDHG"),
fig=fig,
ax=ax[2],
)
fig.show()


fig, ax = plot.subplots(nrows=2, ncols=4, figsize=(20, 10))
norm = plot.matplotlib.colors.Normalize(
vmin=min(snp.abs(x_gt).min(), snp.abs(y).min(), snp.abs(x_tv).min(), snp.abs(x_nltv).min()),
vmax=max(snp.abs(x_gt).max(), snp.abs(y).max(), snp.abs(x_tv).max(), snp.abs(x_nltv).max()),
)
plot.imview(snp.abs(x_gt), title="Ground truth", cbar=None, fig=fig, ax=ax[0, 0], norm=norm)
plot.imview(
snp.abs(y),
title="Measured: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(y)),
cbar=None,
fig=fig,
ax=ax[0, 1],
norm=norm,
)
plot.imview(
snp.abs(x_tv),
title="TV: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(x_tv)),
cbar=None,
fig=fig,
ax=ax[0, 2],
norm=norm,
)
plot.imview(
snp.abs(x_nltv),
title="NL-TV: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(x_nltv)),
cbar=None,
fig=fig,
ax=ax[0, 3],
norm=norm,
)
divider = make_axes_locatable(ax[0, 3])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[0, 3].get_images()[0], cax=cax)
norm = plot.matplotlib.colors.Normalize(
vmin=min(snp.angle(x_gt).min(), snp.angle(x_tv).min(), snp.angle(x_nltv).min()),
vmax=max(snp.angle(x_gt).max(), snp.angle(x_tv).max(), snp.angle(x_nltv).max()),
)
plot.imview(
snp.angle(x_gt),
title="Ground truth",
cbar=None,
fig=fig,
ax=ax[1, 0],
norm=norm,
)
plot.imview(
snp.angle(y),
title="Measured: Mean phase diff. %.2f" % phase_diff(snp.angle(x_gt), snp.angle(y)).mean(),
cbar=None,
fig=fig,
ax=ax[1, 1],
norm=norm,
)
plot.imview(
snp.angle(x_tv),
title="TV: Mean phase diff. %.2f" % phase_diff(snp.angle(x_gt), snp.angle(x_tv)).mean(),
cbar=None,
fig=fig,
ax=ax[1, 2],
norm=norm,
)
plot.imview(
snp.angle(x_nltv),
title="NL-TV: Mean phase diff. %.2f" % phase_diff(snp.angle(x_gt), snp.angle(x_nltv)).mean(),
cbar=None,
fig=fig,
ax=ax[1, 3],
norm=norm,
)
divider = make_axes_locatable(ax[1, 3])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1, 3].get_images()[0], cax=cax)
ax[0, 0].set_ylabel("Magnitude")
ax[1, 0].set_ylabel("Phase")
fig.tight_layout()
fig.show()


input("\nWaiting for input to close figures and exit")
2 changes: 1 addition & 1 deletion examples/scripts/denoise_l1tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"""
N = 256 # image size
phantom = SiemensStar(16)
x_gt = snp.pad(discrete_phantom(phantom, 240), 8)
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
x_gt = 0.5 * x_gt / x_gt.max()
x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU
y = spnoise(x_gt, 0.5)
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/denoise_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"""
N = 256 # image size
phantom = SiemensStar(16)
x_gt = snp.pad(discrete_phantom(phantom, 240), 8)
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU
x_gt = x_gt / x_gt.max()

Expand Down
3 changes: 3 additions & 0 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Miscellaneous
^^^^^^^^^^^^^

- demosaic_ppp_bm3d_admm.py
- denoise_cplx_tv_pdhg.py
- denoise_l1tv_admm.py
- denoise_tv_admm.py
- denoise_tv_pgm.py
Expand Down Expand Up @@ -83,6 +84,7 @@ Total Variation
- deconv_microscopy_allchn_tv_admm.py
- deconv_tv_admm.py
- deconv_tv_admm_tune.py
- denoise_cplx_tv_pdhg.py
- denoise_l1tv_admm.py
- denoise_tv_admm.py
- denoise_tv_pgm.py
Expand Down Expand Up @@ -140,6 +142,7 @@ PDHG
^^^^

- ct_svmbir_tv_multi.py
- denoise_cplx_tv_pdhg.py
- denoise_tv_multi.py


Expand Down
23 changes: 22 additions & 1 deletion scico/_generic_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,33 @@ def jvp(self, primals, tangents):

return jax.jvp(self, primals, tangents)

def jhvp(self, *primals):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Michael-T-McCann: Thoughts on the name of the method? I'm open to suggestions if you think this one is not so great.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind it given that jax has settled on jvp and vjp.

"""Compute a Jacobian-vector product with Hermitian transpose.

Compute the product :math:`[J(\mb{x})]^H \mb{v}` where
:math:`[J(\mb{x})]` is the Jacobian of the operator evaluated
at :math:`\mb{x}`. Instead of directly evaluating the product,
a function is returned that takes :math:`\mb{v}` as an argument.

Args:
primals: Sequence of values at which the Jacobian is
evaluated, with length equal to the number of positional
arguments of `_eval`.
"""

primals, self_vjp = jax.vjp(self, *primals)

def conj_vjp(tangent):
return jax.tree_map(jax.numpy.conj, self_vjp(tangent.conj()))

return primals, conj_vjp

def vjp(self, *primals):
"""Compute a vector-Jacobian product.

Args:
primals: Sequence of values at which the Jacobian is
evaluated, with length equal to the number of position
evaluated, with length equal to the number of positional
arguments of `_eval`.
"""

Expand Down
Loading