Skip to content

Commit

Permalink
Make bm3d dependency optional (#243)
Browse files Browse the repository at this point in the history
* Minor cleanup

* Make bm3d semi-optional

* Update manifest

* Add extras_requires options

* Really add extras_requires options

* Add keywords

* Remove cached_property dependency now that xdesign dependencies correctly specified

* Avoid test failure if bm3d not installed

* Avoid test failure if bm3d not installed

* Fix checks for invalid input shapes

* Fix blockarray tests, add new test

* Remove debugging code

* Minor edit

* Use more appropriate exception

* Add some tests

* Clean up
  • Loading branch information
bwohlberg authored Mar 8, 2022
1 parent 13be7a3 commit 956e55e
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 31 deletions.
7 changes: 4 additions & 3 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
include MANIFEST.in
include setup.py
include setup.cfg
include README.rst
include LICENSE
include CHANGES.rst
include requirements.txt
include dev_requirements.txt
include examples/scriptcheck.sh
include docs/docs_requirements.txt

recursive-include scico *.py
recursive-include scico/data *.png *.npz
recursive-include docs Makefile *.py *.ipynb *.rst *.bib *.css *.svg *.ico
recursive-include examples Makefile *.txt *.rst *.py
include examples/pytojnb examples/scriptcheck
recursive-include examples *_requirements.txt *.txt *.rst *.py
3 changes: 1 addition & 2 deletions examples/examples_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
astra-toolbox
cached_property
colour_demosaicing
xdesign
xdesign>=0.5.5
ray[tune]
hyperopt
20 changes: 13 additions & 7 deletions scico/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from jax.experimental import host_callback as hcb

import bm3d as tunibm3d
try:
import bm3d as tunibm3d
except ImportError:
have_bm3d = False
else:
have_bm3d = True

import scico.numpy as snp
from scico._flax import DnCNNNet, load_weights
Expand All @@ -38,19 +43,21 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False):
Returns:
Denoised output.
"""
if not have_bm3d:
raise RuntimeError("Package bm3d is required for use of this function.")

if is_rgb is True:
bm3d_eval = tunibm3d.bm3d_rgb
else:
bm3d_eval = tunibm3d.bm3d

if np.iscomplexobj(x):
raise TypeError(f"BM3D requries real-valued inputs, got {x.dtype}")
raise TypeError(f"BM3D requires real-valued inputs, got {x.dtype}")

# Support arrays with more than three axes when the additional axes are singletons.
x_in_shape = x.shape

if x.ndim < 2:
if isinstance(x.ndim, tuple) or x.ndim < 2:
raise ValueError(
"BM3D requires two dimensional (M, N) or three dimensional (M, N, C)"
f" inputs; got ndim = {x.ndim}"
Expand All @@ -62,7 +69,7 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False):
# updated; this presumes 'np' profile (bs=8)
if np.min(x.shape[:2]) < 8:
raise ValueError(
f"Two leading dimensions of input cannot be smaller than block size "
"Two leading dimensions of input cannot be smaller than block size "
f"(8); got image size = {x.shape}"
)

Expand Down Expand Up @@ -94,7 +101,6 @@ class DnCNN(FlaxMap):

def __init__(self, variant: str = "6M"):
"""
Note that all DnCNN models are trained for single-channel image
input. Multi-channel input is supported via independent denoising
of each channel.
Expand All @@ -109,7 +115,7 @@ def __init__(self, variant: str = "6M"):
with respect to data in the range [0, 1].
"""
if variant not in ["6L", "6M", "6H", "17L", "17M", "17H"]:
raise RuntimeError(f"Invalid value of parameter variant: {variant}")
raise ValueError(f"Invalid value of parameter variant: {variant}")
if variant[0] == "6":
nlayer = 6
else:
Expand All @@ -130,7 +136,7 @@ def __call__(self, x: JaxArray) -> JaxArray:
if np.iscomplexobj(x):
raise TypeError(f"DnCNN requries real-valued inputs, got {x.dtype}")

if x.ndim < 2:
if isinstance(x.ndim, tuple) or x.ndim < 2:
raise ValueError(
"DnCNN requires two dimensional (M, N) or three dimensional (M, N, C)"
f" inputs; got ndim = {x.ndim}"
Expand Down
4 changes: 2 additions & 2 deletions scico/linop/abel.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def inverse(self, y: JaxArray) -> JaxArray:
"""Perform inverse Abel transform.
Args:
y: Input image (assumed to be a result of an Abel transform)
y: Input image (assumed to be a result of an Abel transform).
Returns:
Output of inverse Abel transform
Output of inverse Abel transform.
"""
return _pyabel_transform(y, direction="inverse", proj_mat_quad=self.proj_mat_quad).astype(
self.input_dtype
Expand Down
2 changes: 0 additions & 2 deletions scico/ray/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.trial import Trial

__author__ = """Brendt Wohlberg <[email protected]>"""


class _CustomReporter(TuneReporterBase):
"""Custom status reporter for :mod:`ray.tune`."""
Expand Down
28 changes: 17 additions & 11 deletions scico/test/test_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import pytest

from scico.denoiser import DnCNN, bm3d
from scico.denoiser import DnCNN, bm3d, have_bm3d
from scico.random import randn


@pytest.mark.skipif(not have_bm3d, reason="bm3d package not installed")
class TestBM3D:
def setup(self):
key = None
Expand Down Expand Up @@ -36,16 +37,16 @@ def test_bad_inputs(self):
x, key = randn((32,), key=None, dtype=np.float32)
with pytest.raises(ValueError):
bm3d(x, 1.0)

x, key = randn((12, 12, 4, 3), key=None, dtype=np.float32)
x, key = randn((12, 12, 4, 3), key=key, dtype=np.float32)
with pytest.raises(ValueError):
bm3d(x, 1.0)

x_b, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32)
x, key = randn(((2, 3), (3, 4, 5)), key=key, dtype=np.float32)
with pytest.raises(ValueError):
bm3d(x, 1.0)

z, key = randn((32, 32), key=None, dtype=np.complex64)
x, key = randn((5, 9), key=key, dtype=np.float32)
with pytest.raises(ValueError):
bm3d(x, 1.0)
z, key = randn((32, 32), key=key, dtype=np.complex64)
with pytest.raises(TypeError):
bm3d(z, 1.0)

Expand All @@ -71,19 +72,24 @@ def test_multi_channel(self):
assert no_jit.dtype == np.float32
assert jitted.dtype == np.float32

def test_init(self):
dncnn = DnCNN(variant="6L")
x = dncnn(self.x_sngchn)
dncnn = DnCNN(variant="17H")
x = dncnn(self.x_mltchn)
with pytest.raises(ValueError):
dncnn = DnCNN(variant="3A")

def test_bad_inputs(self):
x, key = randn((32,), key=None, dtype=np.float32)
with pytest.raises(ValueError):
self.dncnn(x)

x, key = randn((12, 12, 4, 3), key=None, dtype=np.float32)
with pytest.raises(ValueError):
self.dncnn(x)

x_b, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32)
x, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32)
with pytest.raises(ValueError):
self.dncnn(x)

z, key = randn((32, 32), key=None, dtype=np.complex64)
with pytest.raises(TypeError):
self.dncnn(z)
2 changes: 2 additions & 0 deletions scico/test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import scico.numpy as snp
from scico import denoiser, functional
from scico.blockarray import BlockArray
from scico.denoiser import have_bm3d
from scico.random import randn

NO_BLOCK_ARRAY = [functional.L21Norm, functional.NuclearNorm]
Expand Down Expand Up @@ -286,6 +287,7 @@ def foo(c):
np.testing.assert_allclose(non_pmap, pmapped)


@pytest.mark.skipif(not have_bm3d, reason="bm3d package not installed")
class TestBM3D:
def setup(self):
key = None
Expand Down
32 changes: 28 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_init_variable_value(var):
"""

# Set install_requires from requirements.txt file
with open(os.path.join("requirements.txt")) as f:
with open("requirements.txt") as f:
lines = f.readlines()
install_requires = [line.strip() for line in lines]

Expand All @@ -59,8 +59,21 @@ def get_init_variable_value(var):
f"requirements.txt ({req_jax_ver}) do not match"
)

tests_require = ["pytest", "pytest-runner"]
python_requires = ">=3.8"
tests_require = ["pytest", "pytest-runner"]

extra_require_files = [
"dev_requirements.txt",
os.path.join("docs", "docs_requirements.txt"),
os.path.join("examples", "examples_requirements.txt"),
os.path.join("examples", "notebooks_requirements.txt"),
]
extras_require = {"tests": tests_require}
for require_file in extra_require_files:
extras_label = os.path.basename(require_file).partition("_")[0]
with open(require_file) as f:
lines = f.readlines()
extras_require[extras_label] = [line.strip() for line in lines if line[0:2] != "-r"]


setup(
Expand All @@ -69,7 +82,18 @@ def get_init_variable_value(var):
description="Scientific Computational Imaging COde: A Python "
"package for scientific imaging problems",
long_description=longdesc,
keywords=["Computational Imaging", "Inverse Problems", "Optimization", "ADMM", "PGM"],
keywords=[
"Computational Imaging",
"Scientific Imaging",
"Inverse Problems",
"Plug-and-Play Priors",
"Total Variation",
"Optimization",
"ADMM",
"Linearized ADMM",
"PDHG",
"PGM",
],
platforms="Any",
license="BSD",
url="https://github.com/lanl/scico",
Expand All @@ -81,7 +105,7 @@ def get_init_variable_value(var):
python_requires=python_requires,
tests_require=tests_require,
install_requires=install_requires,
extras_require={"tests": tests_require},
extras_require=extras_require,
classifiers=[
"License :: OSI Approved :: BSD License",
"Development Status :: 3 - Alpha",
Expand Down

0 comments on commit 956e55e

Please sign in to comment.