From 3bafb079eaa140520baf9036b3d5b5c44944738d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 4 Mar 2022 19:58:02 -0700 Subject: [PATCH 01/16] Minor cleanup --- scico/linop/abel.py | 4 ++-- scico/ray/tune.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/scico/linop/abel.py b/scico/linop/abel.py index cd0119f27..f50813922 100644 --- a/scico/linop/abel.py +++ b/scico/linop/abel.py @@ -65,10 +65,10 @@ def inverse(self, y: JaxArray) -> JaxArray: """Perform inverse Abel transform. Args: - y: Input image (assumed to be a result of an Abel transform) + y: Input image (assumed to be a result of an Abel transform). Returns: - Output of inverse Abel transform + Output of inverse Abel transform. """ return _pyabel_transform(y, direction="inverse", proj_mat_quad=self.proj_mat_quad).astype( self.input_dtype diff --git a/scico/ray/tune.py b/scico/ray/tune.py index a701bf7fd..726e6baac 100644 --- a/scico/ray/tune.py +++ b/scico/ray/tune.py @@ -24,8 +24,6 @@ from ray.tune.suggest.hyperopt import HyperOptSearch from ray.tune.trial import Trial -__author__ = """Brendt Wohlberg """ - class _CustomReporter(TuneReporterBase): """Custom status reporter for :mod:`ray.tune`.""" From 9f9216b33d4fc254654623b513e52286184b1f82 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 4 Mar 2022 20:02:15 -0700 Subject: [PATCH 02/16] Make bm3d semi-optional --- scico/denoiser.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/scico/denoiser.py b/scico/denoiser.py index d2a11b4ee..6dc894e4c 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -12,7 +12,12 @@ from jax.experimental import host_callback as hcb -import bm3d as tunibm3d +try: + import bm3d as tunibm3d +except ImportError: + have_bm3d = False +else: + have_bm3d = True import scico.numpy as snp from scico._flax import DnCNNNet, load_weights @@ -38,6 +43,8 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): Returns: Denoised output. """ + if not have_bm3d: + raise RuntimeError("Package bm3d is required for use of this function." "") if is_rgb is True: bm3d_eval = tunibm3d.bm3d_rgb @@ -45,7 +52,7 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): bm3d_eval = tunibm3d.bm3d if np.iscomplexobj(x): - raise TypeError(f"BM3D requries real-valued inputs, got {x.dtype}") + raise TypeError(f"BM3D requires real-valued inputs, got {x.dtype}") # Support arrays with more than three axes when the additional axes are singletons. x_in_shape = x.shape @@ -94,7 +101,6 @@ class DnCNN(FlaxMap): def __init__(self, variant: str = "6M"): """ - Note that all DnCNN models are trained for single-channel image input. Multi-channel input is supported via independent denoising of each channel. From 0a37935245938647df76898a7b7143f79869a32e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 4 Mar 2022 20:50:40 -0700 Subject: [PATCH 03/16] Update manifest --- MANIFEST.in | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 663d4cdb0..2a3a1f087 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,13 +1,14 @@ include MANIFEST.in include setup.py -include setup.cfg include README.rst include LICENSE include CHANGES.rst include requirements.txt +include dev_requirements.txt +include examples/scriptcheck.sh +include docs/docs_requirements.txt recursive-include scico *.py recursive-include scico/data *.png *.npz recursive-include docs Makefile *.py *.ipynb *.rst *.bib *.css *.svg *.ico -recursive-include examples Makefile *.txt *.rst *.py -include examples/pytojnb examples/scriptcheck +recursive-include examples *_requirements.txt *.txt *.rst *.py From 4004f8559d71be44f5123c54f04c2959cffe8678 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 4 Mar 2022 20:51:11 -0700 Subject: [PATCH 04/16] Add extras_requires options --- setup.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index fb413baf9..7b30c9915 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ def get_init_variable_value(var): """ # Set install_requires from requirements.txt file -with open(os.path.join("requirements.txt")) as f: +with open("requirements.txt") as f: lines = f.readlines() install_requires = [line.strip() for line in lines] @@ -59,8 +59,21 @@ def get_init_variable_value(var): f"requirements.txt ({req_jax_ver}) do not match" ) -tests_require = ["pytest", "pytest-runner"] python_requires = ">=3.8" +tests_require = ["pytest", "pytest-runner"] + +extra_require_files = [ + "dev_requirements.txt", + os.path.join("docs", "docs_requirements.txt"), + os.path.join("examples", "examples_requirements.txt"), + os.path.join("examples", "notebooks_requirements.txt"), +] +extras_require = {"tests": tests_require} +for require_file in extra_require_files: + extras_label = os.path.basename(require_file).partition("_")[0] + with open(require_file) as f: + lines = f.readlines() + extras_require[extras_label] = [line.strip() for line in lines if line[0:2] != "-r"] setup( @@ -69,7 +82,15 @@ def get_init_variable_value(var): description="Scientific Computational Imaging COde: A Python " "package for scientific imaging problems", long_description=longdesc, - keywords=["Computational Imaging", "Inverse Problems", "Optimization", "ADMM", "PGM"], + keywords=[ + "Computational Imaging", + "Inverse Problems", + "Optimization", + "ADMM", + "Linearized ADMM", + "PDHG", + "PGM", + ], platforms="Any", license="BSD", url="https://github.com/lanl/scico", From c9d6579ad42bb8e428e3c543d839c7c1ddcf0dea Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 4 Mar 2022 20:51:48 -0700 Subject: [PATCH 05/16] Really add extras_requires options --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7b30c9915..b767cd63b 100644 --- a/setup.py +++ b/setup.py @@ -102,7 +102,7 @@ def get_init_variable_value(var): python_requires=python_requires, tests_require=tests_require, install_requires=install_requires, - extras_require={"tests": tests_require}, + extras_require=extras_require, classifiers=[ "License :: OSI Approved :: BSD License", "Development Status :: 3 - Alpha", From d01595a13436c561408e8c7d1d809e577944ba5d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 4 Mar 2022 20:58:07 -0700 Subject: [PATCH 06/16] Add keywords --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index b767cd63b..06cc5b05c 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,10 @@ def get_init_variable_value(var): long_description=longdesc, keywords=[ "Computational Imaging", + "Scientific Imaging", "Inverse Problems", + "Plug-and-Play Priors", + "Total Variation", "Optimization", "ADMM", "Linearized ADMM", From e75a5fb406b9330d5949fc44fc9f5a356aa64c89 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Mar 2022 17:44:13 -0700 Subject: [PATCH 07/16] Remove cached_property dependency now that xdesign dependencies correctly specified --- data | 2 +- examples/examples_requirements.txt | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/data b/data index d7b0478e3..db588186a 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit d7b0478e3769f0cb8ec72b7cf936eb2a997ee8f2 +Subproject commit db588186a4fd3297286448f5df9b7b0f8c0472a9 diff --git a/examples/examples_requirements.txt b/examples/examples_requirements.txt index c76f94e24..2d4eb54d4 100644 --- a/examples/examples_requirements.txt +++ b/examples/examples_requirements.txt @@ -1,6 +1,5 @@ astra-toolbox -cached_property colour_demosaicing -xdesign +xdesign>=0.5.5 ray[tune] hyperopt From 0d62aa53ac8cc148e97c41f912cf572ba18afe5b Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Mar 2022 19:35:57 -0700 Subject: [PATCH 08/16] Avoid test failure if bm3d not installed --- scico/test/test_denoiser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scico/test/test_denoiser.py b/scico/test/test_denoiser.py index 81af5c35d..95aa6dee4 100644 --- a/scico/test/test_denoiser.py +++ b/scico/test/test_denoiser.py @@ -4,10 +4,11 @@ import pytest -from scico.denoiser import DnCNN, bm3d +from scico.denoiser import DnCNN, bm3d, have_bm3d from scico.random import randn +@pytest.mark.skipif(not have_bm3d, reason="bm3d package not installed") class TestBM3D: def setup(self): key = None From e8e4f548994edba2a306ab37b006b988a92051ad Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Mar 2022 19:43:28 -0700 Subject: [PATCH 09/16] Avoid test failure if bm3d not installed --- scico/test/test_functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scico/test/test_functional.py b/scico/test/test_functional.py index 3ce83358f..f95dae06d 100644 --- a/scico/test/test_functional.py +++ b/scico/test/test_functional.py @@ -15,6 +15,7 @@ import scico.numpy as snp from scico import denoiser, functional from scico.blockarray import BlockArray +from scico.denoiser import have_bm3d from scico.random import randn NO_BLOCK_ARRAY = [functional.L21Norm, functional.NuclearNorm] @@ -286,6 +287,7 @@ def foo(c): np.testing.assert_allclose(non_pmap, pmapped) +@pytest.mark.skipif(not have_bm3d, reason="bm3d package not installed") class TestBM3D: def setup(self): key = None From 71c287386a04e9a07b593a3e18d4babb03ff47bf Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Mar 2022 19:57:17 -0700 Subject: [PATCH 10/16] Fix checks for invalid input shapes --- scico/denoiser.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scico/denoiser.py b/scico/denoiser.py index 6dc894e4c..4c51c1682 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -57,7 +57,9 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): # Support arrays with more than three axes when the additional axes are singletons. x_in_shape = x.shape - if x.ndim < 2: + print("ndim: ", x.ndim) + print("shape: ", x.shape) + if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( "BM3D requires two dimensional (M, N) or three dimensional (M, N, C)" f" inputs; got ndim = {x.ndim}" @@ -136,7 +138,7 @@ def __call__(self, x: JaxArray) -> JaxArray: if np.iscomplexobj(x): raise TypeError(f"DnCNN requries real-valued inputs, got {x.dtype}") - if x.ndim < 2: + if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( "DnCNN requires two dimensional (M, N) or three dimensional (M, N, C)" f" inputs; got ndim = {x.ndim}" From 875e3cd4aebff5425febf6d54ae28a2af1720415 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Mar 2022 19:57:50 -0700 Subject: [PATCH 11/16] Fix blockarray tests, add new test --- scico/test/test_denoiser.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scico/test/test_denoiser.py b/scico/test/test_denoiser.py index 95aa6dee4..ae8cdbe3e 100644 --- a/scico/test/test_denoiser.py +++ b/scico/test/test_denoiser.py @@ -38,15 +38,19 @@ def test_bad_inputs(self): with pytest.raises(ValueError): bm3d(x, 1.0) - x, key = randn((12, 12, 4, 3), key=None, dtype=np.float32) + x, key = randn((12, 12, 4, 3), key=key, dtype=np.float32) with pytest.raises(ValueError): bm3d(x, 1.0) - x_b, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32) + x, key = randn(((2, 3), (3, 4, 5)), key=key, dtype=np.float32) with pytest.raises(ValueError): bm3d(x, 1.0) - z, key = randn((32, 32), key=None, dtype=np.complex64) + x, key = randn((5, 9), key=key, dtype=np.float32) + with pytest.raises(ValueError): + bm3d(x, 1.0) + + z, key = randn((32, 32), key=key, dtype=np.complex64) with pytest.raises(TypeError): bm3d(z, 1.0) @@ -81,7 +85,7 @@ def test_bad_inputs(self): with pytest.raises(ValueError): self.dncnn(x) - x_b, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32) + x, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32) with pytest.raises(ValueError): self.dncnn(x) From 02a334b6fe2755e441b1b16670485667156344bd Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Mar 2022 20:13:42 -0700 Subject: [PATCH 12/16] Remove debugging code --- scico/denoiser.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scico/denoiser.py b/scico/denoiser.py index 4c51c1682..8dfffa95e 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -57,8 +57,6 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): # Support arrays with more than three axes when the additional axes are singletons. x_in_shape = x.shape - print("ndim: ", x.ndim) - print("shape: ", x.shape) if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( "BM3D requires two dimensional (M, N) or three dimensional (M, N, C)" From cef233f6af0e0fc1ea96a22a4dafe30b3497ae8b Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Mar 2022 20:14:41 -0700 Subject: [PATCH 13/16] Minor edit --- scico/denoiser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/denoiser.py b/scico/denoiser.py index 8dfffa95e..ffddcd8d8 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -69,7 +69,7 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): # updated; this presumes 'np' profile (bs=8) if np.min(x.shape[:2]) < 8: raise ValueError( - f"Two leading dimensions of input cannot be smaller than block size " + "Two leading dimensions of input cannot be smaller than block size " f"(8); got image size = {x.shape}" ) From df94e61e5eddd633c0165b8f0979ad448cb4bc5a Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Mar 2022 21:06:11 -0700 Subject: [PATCH 14/16] Use more appropriate exception --- scico/denoiser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/denoiser.py b/scico/denoiser.py index ffddcd8d8..52ae82d77 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -115,7 +115,7 @@ def __init__(self, variant: str = "6M"): with respect to data in the range [0, 1]. """ if variant not in ["6L", "6M", "6H", "17L", "17M", "17H"]: - raise RuntimeError(f"Invalid value of parameter variant: {variant}") + raise ValueError(f"Invalid value of parameter variant: {variant}") if variant[0] == "6": nlayer = 6 else: From 54f071b8dd7b54aa9e81fccc75c40de5dc2353c9 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Mar 2022 21:06:29 -0700 Subject: [PATCH 15/16] Add some tests --- scico/test/test_denoiser.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/scico/test/test_denoiser.py b/scico/test/test_denoiser.py index ae8cdbe3e..03d19f022 100644 --- a/scico/test/test_denoiser.py +++ b/scico/test/test_denoiser.py @@ -37,19 +37,15 @@ def test_bad_inputs(self): x, key = randn((32,), key=None, dtype=np.float32) with pytest.raises(ValueError): bm3d(x, 1.0) - x, key = randn((12, 12, 4, 3), key=key, dtype=np.float32) with pytest.raises(ValueError): bm3d(x, 1.0) - x, key = randn(((2, 3), (3, 4, 5)), key=key, dtype=np.float32) with pytest.raises(ValueError): bm3d(x, 1.0) - x, key = randn((5, 9), key=key, dtype=np.float32) with pytest.raises(ValueError): bm3d(x, 1.0) - z, key = randn((32, 32), key=key, dtype=np.complex64) with pytest.raises(TypeError): bm3d(z, 1.0) @@ -76,19 +72,24 @@ def test_multi_channel(self): assert no_jit.dtype == np.float32 assert jitted.dtype == np.float32 + def test_init(self): + dncnn = DnCNN(variant="6L") + x = dncnn(self.x_sngchn) + dncnn = DnCNN(variant="17H") + x = dncnn(self.x_mltchn) + with pytest.raises(ValueError): + dncnn = DnCNN(variant="3A") + def test_bad_inputs(self): x, key = randn((32,), key=None, dtype=np.float32) with pytest.raises(ValueError): self.dncnn(x) - x, key = randn((12, 12, 4, 3), key=None, dtype=np.float32) with pytest.raises(ValueError): self.dncnn(x) - x, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32) with pytest.raises(ValueError): self.dncnn(x) - z, key = randn((32, 32), key=None, dtype=np.complex64) with pytest.raises(TypeError): self.dncnn(z) From 442ca692a8a9099c46b7bb62344fd963385411c3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 8 Mar 2022 07:47:03 -0700 Subject: [PATCH 16/16] Clean up --- scico/denoiser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/denoiser.py b/scico/denoiser.py index 52ae82d77..9dd87b755 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -44,7 +44,7 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): Denoised output. """ if not have_bm3d: - raise RuntimeError("Package bm3d is required for use of this function." "") + raise RuntimeError("Package bm3d is required for use of this function.") if is_rgb is True: bm3d_eval = tunibm3d.bm3d_rgb