diff --git a/data b/data index c3b7db1f6..aa85087b4 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit c3b7db1f69967833d178e029713d5200f9fc6304 +Subproject commit aa85087b471f09a4165f1e465ed69ae33ec95183 diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 36c3d7cb1..6901e9312 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -6,7 +6,7 @@ Usage Examples .. toctree:: :maxdepth: 1 -.. include:: include/exampledepend.rst +.. include:: exampledepend.rst Organized by Application @@ -73,6 +73,7 @@ Miscellaneous examples/denoise_tv_pgm examples/denoise_tv_multi examples/denoise_cplx_tv_pdhg + examples/denoise_dncnn_universal examples/video_rpca_admm diff --git a/docs/source/references.bib b/docs/source/references.bib index a1e83504e..8b598d432 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -23,6 +23,19 @@ @Article {almeida-2013-deconvolving doi = {10.1109/TIP.2013.2258354} } +@Article {balke-2022-scico, + author = {Thilo Balke and Fernando Davis and Cristina + Garcia-Cardona and Soumendu Majee and Michael McCann + and Luke Pfister and Brendt Wohlberg}, + title = {Scientific Computational Imaging Code ({SCICO})}, + journal = {Journal of Open Source Software}, + year = 2022, + volume = 7, + number = 78, + pages = 4722, + doi = {10.21105/joss.04722} +} + @Article {barzilai-1988-stepsize, author = {Jonathan Barzilai and Jonathan M. Borwein}, title = {Two-point step size gradient methods}, @@ -70,7 +83,8 @@ @InCollection{beck-2010-gradient publisher = {Cambridge University Press}, year = 2010, doi = {10.1017/CBO9780511804458.003}, - url = {http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf} + url = + {http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf} } @Book {beck-2017-first, @@ -331,10 +345,10 @@ @Article {kamilov-2022-plug T. Buzzard and Brendt Wohlberg}, title = {Plug-and-Play Methods for Integrating Physical and Learned Models in Computational Imaging}, - journal = {IEEE Signal Processing Magazine}, + journal = {IEEE Signal Processing Magazine}, year = 2022, eprint = {arXiv:2203.17061}, - note = {To appear.} + note = {To appear.} } @Article {liu-2018-first, @@ -361,7 +375,7 @@ @Article {maggioni-2012-nonlocal number = 1, pages = {119--133}, year = 2012, - doi = {10.1109/TIP.2012.2210725} + doi = {10.1109/TIP.2012.2210725} } @InProceedings {makinen-2019-exact, @@ -414,7 +428,7 @@ @Book {nocedal-2006-numerical @Book {paganin-2006-coherent, doi = {10.1093/acprof:oso/9780198567288.001.0001}, - isbn = {9780198567288}, + isbn = 9780198567288, year = 2006, month = Jan, publisher = {Oxford University Press}, @@ -481,19 +495,6 @@ @Article {sauer-1993-local doi = {10.1109/78.193196} } -@Article {balke-2022-scico, - author = {Thilo Balke and Fernando Davis and Cristina - Garcia-Cardona and Soumendu Majee and Michael McCann - and Luke Pfister and Brendt Wohlberg}, - title = {Scientific Computational Imaging Code ({SCICO})}, - journal = {Journal of Open Source Software}, - year = {2022}, - volume = {7}, - number = {78}, - pages = {4722}, - doi = {10.21105/joss.04722} -} - @Article {soulez-2016-proximity, author = {Ferr{\'{e}}ol Soulez and {\'{E}}ric Thi{\'{e}}baut and Antony Schutz and Andr{\'{e}} Ferrari and @@ -506,7 +507,6 @@ @Article {soulez-2016-proximity volume = 55, number = 26, pages = {7412--7421} - } @Article {sreehari-2016-plug, @@ -541,7 +541,7 @@ @Article {valkonen-2014-primal journal = {Inverse Problems}, volume = 30, number = 5, - pages = {055012}, + pages = 055012, year = 2014, doi = {10.1088/0266-5611/30/5/055012} } @@ -609,6 +609,20 @@ @Article {zhang-2017-dncnn pages = {3142--3155} } +@Article {zhang-2021-plug, + author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and + Zhang, Lei and Van Gool, Luc and Timofte, Radu}, + title = {Plug-and-Play Image Restoration With Deep Denoiser + Prior}, + journal = {IEEE Transactions on Pattern Analysis and Machine + Intelligence}, + year = 2022, + volume = 44, + number = 10, + doi = {10.1109/TPAMI.2021.3088914}, + pages = {6360--6376} +} + @Article {zhou-2006-adaptive, author = {Bin Zhou and Li Gao and Yu-Hong Dai}, title = {Gradient Methods with Adaptive Step-Sizes}, diff --git a/docs/source/team.rst b/docs/source/team.rst index 8c707bf80..bf2cbda1a 100644 --- a/docs/source/team.rst +++ b/docs/source/team.rst @@ -43,3 +43,4 @@ Contributors - `Yanpeng Yuan `_ (ASTRA interface improvements) - `Saurav Maheshkar `_ (Improvements to pre-commit configuration) - `Andrew Leong `_ (Improvements to optics module documentation) +- `Weijie Gan `_ (Non-blind variant of DnCNN) diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index a82de9c18..75e0fd064 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -78,6 +78,8 @@ Miscellaneous Comparison of Optimization Algorithms for Total Variation Denoising `denoise_cplx_tv_pdhg.py `_ Complex Total Variation Denoising + `denoise_dncnn_universal.py `_ + Comparison of DnCNN Variants for Image Denoising `video_rpca_admm.py `_ Video Decomposition via Robust PCA diff --git a/examples/scripts/denoise_dncnn_universal.py b/examples/scripts/denoise_dncnn_universal.py new file mode 100644 index 000000000..0889ad089 --- /dev/null +++ b/examples/scripts/denoise_dncnn_universal.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +""" +Comparison of DnCNN Variants for Image Denoising +================================================ + +This example demonstrates the solution of an image denoising problem +using DnCNN :cite:`zhang-2017-dncnn` networks trained for different noise +levels, as well as custom variants with fewer network layers, and with a +noise level input. + +The networks trained for specific noise levels are labeled 6L, 6M, 6H, +17L, 17M, and 17H, where {6, 17} denote the number of layers, and {L, M, +H} represent noise standard deviation of the training images (0.06, 0.10, +and 0.20 respectively). The networks with a noise standard deviation +input are labeled 6N and 17N, where {6, 17} again denote the number of +layers. +""" + +import numpy as np + +import jax + +from xdesign import Foam, discrete_phantom + +import scico.random +from scico import metric, plot +from scico.denoiser import DnCNN + +""" +Create a ground truth image. +""" +np.random.seed(1234) +N = 512 # image size +x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) +x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU + +""" +Test different DnCNN variants on images with different noise levels. +""" +print(" σ | variant | noisy image PSNR (dB) | denoised image PSNR (dB)") +for σ in [0.06, 0.10, 0.20]: + print("------+---------+-------------------------+-------------------------") + for variant in ["17L", "17M", "17H", "17N", "6L", "6M", "6H", "6N"]: + + # Instantiate a DnCNN. + denoiser = DnCNN(variant=variant) + + # Generate a noisy image. + noise, key = scico.random.randn(x_gt.shape, seed=0) + y = x_gt + σ * noise + + if variant in ["6N", "17N"]: + x_hat = denoiser(y, sigma=σ) + else: + x_hat = denoiser(y) + + x_hat = np.clip(x_hat, a_min=0, a_max=1.0) + + if variant[0] == "6": + variant += " " # add spaces to maintain alignment + + print( + " %.2f | %s | %.2f | %.2f " + % (σ, variant, metric.psnr(x_gt, y), metric.psnr(x_gt, x_hat)) + ) + + +""" +Show reference and denoised images for σ=0.2 and variant=6N. +""" +fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7)) +plot.imview(x_gt, title="Reference", fig=fig, ax=ax[0]) +plot.imview(y, title="Noisy image: %.2f (dB)" % metric.psnr(x_gt, y), fig=fig, ax=ax[1]) +plot.imview(x_hat, title="Denoised image: %.2f (dB)" % metric.psnr(x_gt, x_hat), fig=fig, ax=ax[2]) +fig.show() + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index c14cf69d5..488b5e553 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -51,6 +51,7 @@ Miscellaneous - denoise_tv_pgm.py - denoise_tv_multi.py - denoise_cplx_tv_pdhg.py + - denoise_dncnn_universal.py - video_rpca_admm.py diff --git a/scico/denoiser.py b/scico/denoiser.py index 7c54a64b9..442f7d69a 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -8,7 +8,7 @@ """Interfaces to standard denoisers.""" -from typing import Any, Union +from typing import Any, Optional, Union import numpy as np @@ -53,8 +53,8 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False, profile: Union[BM3DPro x: Input image. Expected to be a 2D array (gray-scale denoising) or 3D array (color denoising). Higher-dimensional arrays are tolerated only if the additional dimensions are singletons. - For color denoising, the color channel is assumed to be in the - last non-singleton dimension. + For color denoising, the color channel is assumed to be in + the last non-singleton dimension. sigma: Noise parameter. is_rgb: Flag indicating use of BM3D with a color transform. Default: ``False``. @@ -182,42 +182,75 @@ class DnCNN(FlaxMap): Note that :class:`.DnCNNNet` represents an untrained form of the generic DnCNN CNN structure, while this class represents a trained form with six or seventeen layers. + + The standard DnCNN as proposed in :cite:`zhang-2017-dncnn` does not + have a noise level input. This implementation of DnCNN also supports + a custom variant that includes a noise standard deviation input, + `sigma`, which is included in the network as an additional channel + consisting of a constant array with value `sigma`. This network was + trained with image data on the range [0, 1], and with noise standard + deviations ranging from 0.0 to 0.2. It is worth noting that DRUNet + :cite:`zhang-2021-plug`, another recent approach to including a noise + level input in a CNN denoiser, is based on a substantially different + network architecture. """ def __init__(self, variant: str = "6M"): """ Note that all DnCNN models are trained for single-channel image input. Multi-channel input is supported via independent denoising - of each channel. + of each channel. Input images are expected to have pixel values + in the range [0, 1]. Args: variant: Identify the DnCNN model to be used. Options are - '6L', '6M' (default), '6H', '17L', '17M', and '17H', - where the integer indicates the number of layers in the - network, and the postfix indicates the training noise - standard deviation: L (low) = 0.06, M (mid) = 0.1, - H (high) = 0.2, where the standard deviations are - with respect to data in the range [0, 1]. + '6L', '6M' (default), '6H', '6N', '17L', '17M', '17H', + and '17N', where the integer indicates the number of + layers in the network, and the postfix indicates the + training noise standard deviation (with respect to data + in the range [0, 1]): L (low) = 0.06, M (mid) = 0.10, + H (high) = 0.20, or N indicating that a noise standard + deviation input, `sigma`, is available. """ - if variant not in ["6L", "6M", "6H", "17L", "17M", "17H"]: + + self.variant = variant + + if variant not in ["6L", "6M", "6H", "17L", "17M", "17H", "6N", "17N"]: raise ValueError(f"Invalid value {variant} of parameter variant.") if variant[0] == "6": nlayer = 6 else: nlayer = 17 - model = DnCNNNet(depth=nlayer, channels=1, num_filters=64, dtype=np.float32) + + channels = 2 if variant in ["6N", "17N"] else 1 + + model = DnCNNNet(depth=nlayer, channels=channels, num_filters=64, dtype=np.float32) variables = load_weights(_flax_data_path("dncnn%s.npz" % variant)) super().__init__(model, variables) - def __call__(self, x: JaxArray) -> JaxArray: + def __call__(self, x: JaxArray, sigma: Optional[float] = None) -> JaxArray: r"""Apply DnCNN denoiser. Args: x: Input array. + sigma: Noise standard deviation (for variants `6N` and `17N`). Returns: Denoised output. """ + if sigma is not None and self.variant not in ["6N", "17N"]: + raise ValueError( + "A non-default value for the sigma parameter may " + "only be specified when the variant is 6N or 17N" + f"; got variant = {self.variant}." + ) + + if sigma is None and self.variant in ["6N", "17N"]: + raise ValueError( + "A float value must be specified for the sigma " + "parameter when the variant is 6N or 17N." + ) + if snp.util.is_complex_dtype(x.dtype): raise TypeError(f"DnCNN requries real-valued inputs, got {x.dtype}.") @@ -238,13 +271,28 @@ def __call__(self, x: JaxArray) -> JaxArray: ) if x.ndim == 3: + y = snp.swapaxes(x, 0, -1) + + if sigma is not None: + y = snp.stack([y, snp.ones_like(y) * sigma], -1) + else: + y = y[..., np.newaxis] + # swap channel axis to batch axis and add singleton axis at end - y = super().__call__(snp.swapaxes(x, 0, -1)[..., np.newaxis]) + y = super().__call__(y) # drop singleton axis and swap axes back to original positions y = snp.swapaxes(y[..., 0], 0, -1) + else: + if sigma is not None: + x = snp.stack([x, snp.ones_like(x) * sigma], -1) + x = x[np.newaxis, ...] + y = super().__call__(x) + if sigma is not None: + y = y[0, ..., 0] + y = y.reshape(x_in_shape) return y diff --git a/scico/test/test_denoiser.py b/scico/test/test_denoiser.py index 51869b250..8e36c8b0e 100644 --- a/scico/test/test_denoiser.py +++ b/scico/test/test_denoiser.py @@ -144,3 +144,24 @@ def test_bad_inputs(self): z, key = randn((32, 32), key=None, dtype=np.complex64) with pytest.raises(TypeError): self.dncnn(z) + + +class TestNonBLindDnCNN: + def setup_method(self): + key = None + self.x_sngchn, key = randn((32, 33), key=key, dtype=np.float32) + self.x_mltchn, key = randn((33, 34, 5), key=key, dtype=np.float32) + self.sigma = 0.1 + self.dncnn = DnCNN(variant="6N") + + def test_single_channel(self): + rslt = self.dncnn(self.x_sngchn, sigma=self.sigma) + assert rslt.dtype == np.float32 + + def test_multi_channel(self): + rslt = self.dncnn(self.x_mltchn, sigma=self.sigma) + assert rslt.dtype == np.float32 + + def test_bad_inputs(self): + with pytest.raises(ValueError): + rslt = self.dncnn(self.x_sngchn)