Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DnCNN with noise level input #349

Merged
merged 32 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
62750da
fixed bug
wjgancn Oct 3, 2022
ef41fc0
add DnCNN with nosie level input
wjgancn Oct 3, 2022
ab2566b
Merge branch 'wjgancn/dnnm_testing' of github.com:wjgancn/scico into …
wjgancn Oct 3, 2022
a8eda63
add DnCNN with nosie level input
wjgancn Oct 3, 2022
913188c
Merge branch 'wjgancn/dnnm_testing' of github.com:wjgancn/scico into …
wjgancn Oct 3, 2022
e87cb63
Merge branch 'main' into wjgancn/dnnm_testing
bwohlberg Oct 13, 2022
c3f669e
Merge branch 'lanl:main' into wjgancn/dnnm_testing
wjgancn Oct 24, 2022
f46f18a
revise based on suggestion
wjgancn Oct 29, 2022
91ae244
Merge branch 'lanl:main' into wjgancn/dnnm_testing
wjgancn Oct 29, 2022
af1e0ea
Merge branch 'main' into wjgancn/dnnm_testing
bwohlberg Nov 22, 2022
e2c5cce
Docs improvements and clean up
bwohlberg Nov 22, 2022
0e316b2
Minor improvement
bwohlberg Nov 22, 2022
f55e7bd
Minor improvement
bwohlberg Nov 22, 2022
a17260a
Correct submodule reference
bwohlberg Nov 22, 2022
f314b9c
Add tests
bwohlberg Nov 22, 2022
d229788
Update submodule
bwohlberg Nov 22, 2022
ed72003
Update contributor list
bwohlberg Nov 22, 2022
1e1629d
Apply isort
bwohlberg Nov 22, 2022
eb5d1ae
Merge branch 'main' into wjgancn/dnnm_testing
bwohlberg Nov 22, 2022
d837ae9
fix conflicts
wjgancn Dec 8, 2022
556c15e
fix conflicts
wjgancn Dec 8, 2022
293e53f
Merge branch 'lanl:main' into wjgancn/dnnm_testing
wjgancn Dec 8, 2022
f22f68f
fix a minor issue
wjgancn Dec 8, 2022
4c02e4a
Merge branch 'wjgancn/dnnm_testing' of github.com:wjgancn/scico into …
wjgancn Dec 8, 2022
a95072f
fix issues
wjgancn Dec 8, 2022
a409035
add Optional in denoiser.py
wjgancn Dec 8, 2022
15a3179
add more descriptions
wjgancn Dec 11, 2022
711a31e
Update examples index
bwohlberg Dec 12, 2022
c1465fd
Docstring changes
bwohlberg Dec 12, 2022
42cf0ce
Docstring and comment changes
bwohlberg Dec 12, 2022
b5d5213
Minor docstring edit
bwohlberg Dec 12, 2022
2cb70af
Update submodule
bwohlberg Dec 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data
3 changes: 2 additions & 1 deletion docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Usage Examples
.. toctree::
:maxdepth: 1

.. include:: include/exampledepend.rst
.. include:: exampledepend.rst


Organized by Application
Expand Down Expand Up @@ -73,6 +73,7 @@ Miscellaneous
examples/denoise_tv_pgm
examples/denoise_tv_multi
examples/denoise_cplx_tv_pdhg
examples/denoise_dncnn_universal
examples/video_rpca_admm


Expand Down
54 changes: 34 additions & 20 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ @Article {almeida-2013-deconvolving
doi = {10.1109/TIP.2013.2258354}
}

@Article {balke-2022-scico,
author = {Thilo Balke and Fernando Davis and Cristina
Garcia-Cardona and Soumendu Majee and Michael McCann
and Luke Pfister and Brendt Wohlberg},
title = {Scientific Computational Imaging Code ({SCICO})},
journal = {Journal of Open Source Software},
year = 2022,
volume = 7,
number = 78,
pages = 4722,
doi = {10.21105/joss.04722}
}

@Article {barzilai-1988-stepsize,
author = {Jonathan Barzilai and Jonathan M. Borwein},
title = {Two-point step size gradient methods},
Expand Down Expand Up @@ -70,7 +83,8 @@ @InCollection{beck-2010-gradient
publisher = {Cambridge University Press},
year = 2010,
doi = {10.1017/CBO9780511804458.003},
url = {http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf}
url =
{http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf}
}

@Book {beck-2017-first,
Expand Down Expand Up @@ -331,10 +345,10 @@ @Article {kamilov-2022-plug
T. Buzzard and Brendt Wohlberg},
title = {Plug-and-Play Methods for Integrating Physical and
Learned Models in Computational Imaging},
journal = {IEEE Signal Processing Magazine},
journal = {IEEE Signal Processing Magazine},
year = 2022,
eprint = {arXiv:2203.17061},
note = {To appear.}
note = {To appear.}
}

@Article {liu-2018-first,
Expand All @@ -361,7 +375,7 @@ @Article {maggioni-2012-nonlocal
number = 1,
pages = {119--133},
year = 2012,
doi = {10.1109/TIP.2012.2210725}
doi = {10.1109/TIP.2012.2210725}
}

@InProceedings {makinen-2019-exact,
Expand Down Expand Up @@ -414,7 +428,7 @@ @Book {nocedal-2006-numerical

@Book {paganin-2006-coherent,
doi = {10.1093/acprof:oso/9780198567288.001.0001},
isbn = {9780198567288},
isbn = 9780198567288,
year = 2006,
month = Jan,
publisher = {Oxford University Press},
Expand Down Expand Up @@ -481,19 +495,6 @@ @Article {sauer-1993-local
doi = {10.1109/78.193196}
}

@Article {balke-2022-scico,
author = {Thilo Balke and Fernando Davis and Cristina
Garcia-Cardona and Soumendu Majee and Michael McCann
and Luke Pfister and Brendt Wohlberg},
title = {Scientific Computational Imaging Code ({SCICO})},
journal = {Journal of Open Source Software},
year = {2022},
volume = {7},
number = {78},
pages = {4722},
doi = {10.21105/joss.04722}
}

@Article {soulez-2016-proximity,
author = {Ferr{\'{e}}ol Soulez and {\'{E}}ric Thi{\'{e}}baut
and Antony Schutz and Andr{\'{e}} Ferrari and
Expand All @@ -506,7 +507,6 @@ @Article {soulez-2016-proximity
volume = 55,
number = 26,
pages = {7412--7421}

}

@Article {sreehari-2016-plug,
Expand Down Expand Up @@ -541,7 +541,7 @@ @Article {valkonen-2014-primal
journal = {Inverse Problems},
volume = 30,
number = 5,
pages = {055012},
pages = 055012,
year = 2014,
doi = {10.1088/0266-5611/30/5/055012}
}
Expand Down Expand Up @@ -609,6 +609,20 @@ @Article {zhang-2017-dncnn
pages = {3142--3155}
}

@Article {zhang-2021-plug,
author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and
Zhang, Lei and Van Gool, Luc and Timofte, Radu},
title = {Plug-and-Play Image Restoration With Deep Denoiser
Prior},
journal = {IEEE Transactions on Pattern Analysis and Machine
Intelligence},
year = 2022,
volume = 44,
number = 10,
doi = {10.1109/TPAMI.2021.3088914},
pages = {6360--6376}
}

@Article {zhou-2006-adaptive,
author = {Bin Zhou and Li Gao and Yu-Hong Dai},
title = {Gradient Methods with Adaptive Step-Sizes},
Expand Down
1 change: 1 addition & 0 deletions docs/source/team.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ Contributors
- `Yanpeng Yuan <https://github.com/yanpeng7>`_ (ASTRA interface improvements)
- `Saurav Maheshkar <https://github.com/SauravMaheshkar>`_ (Improvements to pre-commit configuration)
- `Andrew Leong <https://scholar.google.com/citations?user=-2wRWbcAAAAJ&hl=en>`_ (Improvements to optics module documentation)
- `Weijie Gan <https://github.com/wjgancn>`_ (Non-blind variant of DnCNN)
2 changes: 2 additions & 0 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ Miscellaneous
Comparison of Optimization Algorithms for Total Variation Denoising
`denoise_cplx_tv_pdhg.py <denoise_cplx_tv_pdhg.py>`_
Complex Total Variation Denoising
`denoise_dncnn_universal.py <denoise_dncnn_universal.py>`_
Comparison of DnCNN Variants for Image Denoising
`video_rpca_admm.py <video_rpca_admm.py>`_
Video Decomposition via Robust PCA

Expand Down
82 changes: 82 additions & 0 deletions examples/scripts/denoise_dncnn_universal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""
Comparison of DnCNN Variants for Image Denoising
================================================

This example demonstrates the solution of an image denoising problem
using DnCNN :cite:`zhang-2017-dncnn` networks trained for different noise
levels, as well as custom variants with fewer network layers, and with a
noise level input.

The networks trained for specific noise levels are labeled 6L, 6M, 6H,
17L, 17M, and 17H, where {6, 17} denote the number of layers, and {L, M,
H} represent noise standard deviation of the training images (0.06, 0.10,
and 0.20 respectively). The networks with a noise standard deviation
input are labeled 6N and 17N, where {6, 17} again denote the number of
layers.
"""

import numpy as np

import jax

from xdesign import Foam, discrete_phantom

import scico.random
from scico import metric, plot
from scico.denoiser import DnCNN

"""
Create a ground truth image.
"""
np.random.seed(1234)
N = 512 # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU

"""
Test different DnCNN variants on images with different noise levels.
"""
print(" σ | variant | noisy image PSNR (dB) | denoised image PSNR (dB)")
for σ in [0.06, 0.10, 0.20]:
print("------+---------+-------------------------+-------------------------")
for variant in ["17L", "17M", "17H", "17N", "6L", "6M", "6H", "6N"]:

# Instantiate a DnCNN.
denoiser = DnCNN(variant=variant)

# Generate a noisy image.
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise

if variant in ["6N", "17N"]:
x_hat = denoiser(y, sigma=σ)
else:
x_hat = denoiser(y)

x_hat = np.clip(x_hat, a_min=0, a_max=1.0)

if variant[0] == "6":
variant += " " # add spaces to maintain alignment

print(
" %.2f | %s | %.2f | %.2f "
% (σ, variant, metric.psnr(x_gt, y), metric.psnr(x_gt, x_hat))
)


"""
Show reference and denoised images for σ=0.2 and variant=6N.
"""
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))
plot.imview(x_gt, title="Reference", fig=fig, ax=ax[0])
plot.imview(y, title="Noisy image: %.2f (dB)" % metric.psnr(x_gt, y), fig=fig, ax=ax[1])
plot.imview(x_hat, title="Denoised image: %.2f (dB)" % metric.psnr(x_gt, x_hat), fig=fig, ax=ax[2])
fig.show()

input("\nWaiting for input to close figures and exit")
1 change: 1 addition & 0 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Miscellaneous
- denoise_tv_pgm.py
- denoise_tv_multi.py
- denoise_cplx_tv_pdhg.py
- denoise_dncnn_universal.py
- video_rpca_admm.py


Expand Down
76 changes: 62 additions & 14 deletions scico/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""Interfaces to standard denoisers."""


from typing import Any, Union
from typing import Any, Optional, Union

import numpy as np

Expand Down Expand Up @@ -53,8 +53,8 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False, profile: Union[BM3DPro
x: Input image. Expected to be a 2D array (gray-scale denoising)
or 3D array (color denoising). Higher-dimensional arrays are
tolerated only if the additional dimensions are singletons.
For color denoising, the color channel is assumed to be in the
last non-singleton dimension.
For color denoising, the color channel is assumed to be in
the last non-singleton dimension.
sigma: Noise parameter.
is_rgb: Flag indicating use of BM3D with a color transform.
Default: ``False``.
Expand Down Expand Up @@ -182,42 +182,75 @@ class DnCNN(FlaxMap):
Note that :class:`.DnCNNNet` represents an untrained form of the
generic DnCNN CNN structure, while this class represents a trained
form with six or seventeen layers.

The standard DnCNN as proposed in :cite:`zhang-2017-dncnn` does not
have a noise level input. This implementation of DnCNN also supports
a custom variant that includes a noise standard deviation input,
`sigma`, which is included in the network as an additional channel
consisting of a constant array with value `sigma`. This network was
trained with image data on the range [0, 1], and with noise standard
deviations ranging from 0.0 to 0.2. It is worth noting that DRUNet
:cite:`zhang-2021-plug`, another recent approach to including a noise
level input in a CNN denoiser, is based on a substantially different
network architecture.
"""

def __init__(self, variant: str = "6M"):
"""
Note that all DnCNN models are trained for single-channel image
input. Multi-channel input is supported via independent denoising
of each channel.
of each channel. Input images are expected to have pixel values
in the range [0, 1].

Args:
variant: Identify the DnCNN model to be used. Options are
'6L', '6M' (default), '6H', '17L', '17M', and '17H',
where the integer indicates the number of layers in the
network, and the postfix indicates the training noise
standard deviation: L (low) = 0.06, M (mid) = 0.1,
H (high) = 0.2, where the standard deviations are
with respect to data in the range [0, 1].
'6L', '6M' (default), '6H', '6N', '17L', '17M', '17H',
and '17N', where the integer indicates the number of
layers in the network, and the postfix indicates the
training noise standard deviation (with respect to data
in the range [0, 1]): L (low) = 0.06, M (mid) = 0.10,
H (high) = 0.20, or N indicating that a noise standard
deviation input, `sigma`, is available.
"""
if variant not in ["6L", "6M", "6H", "17L", "17M", "17H"]:

self.variant = variant

if variant not in ["6L", "6M", "6H", "17L", "17M", "17H", "6N", "17N"]:
raise ValueError(f"Invalid value {variant} of parameter variant.")
if variant[0] == "6":
nlayer = 6
else:
nlayer = 17
model = DnCNNNet(depth=nlayer, channels=1, num_filters=64, dtype=np.float32)

channels = 2 if variant in ["6N", "17N"] else 1

model = DnCNNNet(depth=nlayer, channels=channels, num_filters=64, dtype=np.float32)
variables = load_weights(_flax_data_path("dncnn%s.npz" % variant))
super().__init__(model, variables)

def __call__(self, x: JaxArray) -> JaxArray:
def __call__(self, x: JaxArray, sigma: Optional[float] = None) -> JaxArray:
r"""Apply DnCNN denoiser.

Args:
x: Input array.
sigma: Noise standard deviation (for variants `6N` and `17N`).

Returns:
Denoised output.
"""
if sigma is not None and self.variant not in ["6N", "17N"]:
raise ValueError(
"A non-default value for the sigma parameter may "
"only be specified when the variant is 6N or 17N"
f"; got variant = {self.variant}."
)

if sigma is None and self.variant in ["6N", "17N"]:
raise ValueError(
"A float value must be specified for the sigma "
"parameter when the variant is 6N or 17N."
)

if snp.util.is_complex_dtype(x.dtype):
raise TypeError(f"DnCNN requries real-valued inputs, got {x.dtype}.")

Expand All @@ -238,13 +271,28 @@ def __call__(self, x: JaxArray) -> JaxArray:
)

if x.ndim == 3:
y = snp.swapaxes(x, 0, -1)

if sigma is not None:
y = snp.stack([y, snp.ones_like(y) * sigma], -1)
else:
y = y[..., np.newaxis]

# swap channel axis to batch axis and add singleton axis at end
y = super().__call__(snp.swapaxes(x, 0, -1)[..., np.newaxis])
y = super().__call__(y)
# drop singleton axis and swap axes back to original positions
y = snp.swapaxes(y[..., 0], 0, -1)

else:
if sigma is not None:
x = snp.stack([x, snp.ones_like(x) * sigma], -1)
x = x[np.newaxis, ...]

y = super().__call__(x)

if sigma is not None:
y = y[0, ..., 0]

y = y.reshape(x_in_shape)

return y
Loading