Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loss functions for phase retrieval #236

Merged
merged 25 commits into from
Mar 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,21 @@ @Article {sauer-1993-local
doi = {10.1109/78.193196}
}

@Article {soulez-2016-proximity,
author = {Ferr{\'{e}}ol Soulez and {\'{E}}ric Thi{\'{e}}baut
and Antony Schutz and Andr{\'{e}} Ferrari and
Fr{\'{e}}d{\'{e}}ric Courbin and Michael Unser},
title = {Proximity operators for phase retrieval},
journal = {Applied Optics},
doi = {10.1364/ao.55.007412},
year = 2016,
month = Sep,
volume = 55,
number = 26,
pages = {7412--7421}

}

@Article {sreehari-2016-plug,
author = {Suhas Sreehari and Singanallur V. Venkatakrishnan
and Brendt Wohlberg and Gregery T. Buzzard and
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_astra_weighted_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def postprocess(x):
lambda_weighted = 1.14e2

weights = jax.device_put(counts / Io)
f = loss.WeightedSquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))
f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))

admm_weighted = ADMM(
f=f,
Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

This version uses the data fidelity term as the ADMM f, and thus the
optimization with respect to the data fidelity uses CG rather than the
prox of the SVMBIRWeightedSquaredL2Loss functional.
prox of the SVMBIRSquaredL2Loss functional.
"""

import numpy as np
Expand All @@ -32,7 +32,7 @@
from scico import metric, plot
from scico.functional import BM3D, NonNegativeIndicator
from scico.linop import Diagonal, Identity
from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRWeightedSquaredL2Loss
from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRSquaredL2Loss
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand Down Expand Up @@ -92,7 +92,7 @@
ρ = 15 # ADMM penalty parameter
σ = density * 0.18 # denoiser sigma

f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g0 = σ * ρ * BM3D()
g1 = NonNegativeIndicator()

Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

This version uses the data fidelity term as one of the ADMM g functionals,
and thus the optimization with respect to the data fidelity is able to
exploit the internal prox of the SVMBIRWeightedSquaredL2Loss functional.
exploit the internal prox of the SVMBIRSquaredL2Loss functional.
"""

import numpy as np
Expand All @@ -32,7 +32,7 @@
from scico import metric, plot
from scico.functional import BM3D, NonNegativeIndicator
from scico.linop import Diagonal, Identity
from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRWeightedSquaredL2Loss
from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRSquaredL2Loss
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand Down Expand Up @@ -92,7 +92,7 @@
ρ = 10 # ADMM penalty parameter
σ = density * 0.26 # denoiser sigma

f = SVMBIRWeightedSquaredL2Loss(
f = SVMBIRSquaredL2Loss(
y=y, A=A, W=Diagonal(weights), scale=0.5, prox_kwargs={"maxiter": 5, "ctol": 0.0}
)
g0 = σ * ρ * BM3D()
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/ct_svmbir_tv_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import scico.numpy as snp
from scico import functional, linop, metric, plot
from scico.linop import Diagonal
from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRWeightedSquaredL2Loss
from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRSquaredL2Loss
from scico.optimize import PDHG, LinearizedADMM
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
Expand Down Expand Up @@ -84,7 +84,7 @@

λ = 1e-1 # L1 norm regularization parameter

f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g = λ * functional.L21Norm() # regularization functional

# The append=0 option makes the results of horizontal and vertical finite
Expand Down
35 changes: 16 additions & 19 deletions scico/linop/radon_svmbir.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import jax.experimental.host_callback

import scico.numpy as snp
from scico.loss import Loss, WeightedSquaredL2Loss
from scico.loss import Loss, SquaredL2Loss
from scico.typing import JaxArray, Shape

from ._linop import Diagonal, Identity, LinearOperator
Expand All @@ -38,9 +38,9 @@ class ParallelBeamProjector(LinearOperator):
``is_masked`` option selects whether a valid region for projections
(pixels outside this region are ignored when performing the
projection) is active. This region of validity is also respected by
:meth:`.SVMBIRWeightedSquaredL2Loss.prox` when
:class:`.SVMBIRWeightedSquaredL2Loss` is initialized with a
:class:`ParallelBeamProjector` with this option enabled.
:meth:`.SVMBIRSquaredL2Loss.prox` when :class:`.SVMBIRSquaredL2Loss`
is initialized with a :class:`ParallelBeamProjector` with this option
enabled.
"""

def __init__(
Expand Down Expand Up @@ -178,8 +178,7 @@ def _bproj_hcb(self, y):


class SVMBIRExtendedLoss(Loss):
r"""Extended Weighted squared :math:`\ell_2` loss with svmbir CT
projector.
r"""Extended squared :math:`\ell_2` loss with svmbir CT projector.

Generalization of the weighted squared :math:`\ell_2` loss of a CT
reconstruction problem,
Expand All @@ -195,13 +194,12 @@ class SVMBIRExtendedLoss(Loss):
to :class:`scico.linop.Identity`.

The extended loss differs from a typical weighted squared
:math:`\ell_2` loss as follows.
When ``positivity=True``, the prox projects onto the non-negative
orthant and the loss is infinite if any element of the input is
negative. When the ``is_masked`` option of the associated
:class:`.ParallelBeamProjector` is ``True``, the reconstruction is
computed over a masked region of the image as described
in class :class:`.ParallelBeamProjector`.
:math:`\ell_2` loss as follows. When `positivity=True`, the prox
projects onto the non-negative orthant and the loss is infinite if
any element of the input is negative. When the `is_masked` option
of the associated :class:`.ParallelBeamProjector` is ``True``, the
reconstruction is computed over a masked region of the image as
described in class :class:`.ParallelBeamProjector`.
"""

def __init__(
Expand Down Expand Up @@ -299,7 +297,7 @@ def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray:
return jax.device_put(result.reshape(self.A.input_shape))


class SVMBIRWeightedSquaredL2Loss(SVMBIRExtendedLoss, WeightedSquaredL2Loss):
class SVMBIRSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss):
r"""Weighted squared :math:`\ell_2` loss with svmbir CT projector.

Weighted squared :math:`\ell_2` loss of a CT reconstruction problem,
Expand All @@ -309,8 +307,8 @@ class SVMBIRWeightedSquaredL2Loss(SVMBIRExtendedLoss, WeightedSquaredL2Loss):
\alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} -
A(\mb{x})\right) \;,

where :math:`A` is a :class:`.ParallelBeamProjector`,
:math:`\alpha` is the scaling parameter and :math:`W` is an instance
where :math:`A` is a :class:`.ParallelBeamProjector`, :math:`\alpha`
is the scaling parameter and :math:`W` is an instance
of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set
to :class:`scico.linop.Identity`.
"""
Expand All @@ -321,7 +319,7 @@ def __init__(
prox_kwargs: Optional[dict] = None,
**kwargs,
):
r"""Initialize a :class:`SVMBIRWeightedSquaredL2Loss` object.
r"""Initialize a :class:`SVMBIRSquaredL2Loss` object.

Args:
y: Sinogram measurement.
Expand All @@ -337,8 +335,7 @@ def __init__(

if self.A.is_masked:
raise ValueError(
"is_masked must be false for the ParallelBeamProjector in "
"SVMBIRWeightedSquaredL2Loss."
"is_masked must be false for the ParallelBeamProjector in " "SVMBIRSquaredL2Loss."
)


Expand Down
Loading