Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make bm3d dependency optional #243

Merged
merged 17 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
include MANIFEST.in
include setup.py
include setup.cfg
include README.rst
include LICENSE
include CHANGES.rst
include requirements.txt
include dev_requirements.txt
include examples/scriptcheck.sh
include docs/docs_requirements.txt

recursive-include scico *.py
recursive-include scico/data *.png *.npz
recursive-include docs Makefile *.py *.ipynb *.rst *.bib *.css *.svg *.ico
recursive-include examples Makefile *.txt *.rst *.py
include examples/pytojnb examples/scriptcheck
recursive-include examples *_requirements.txt *.txt *.rst *.py
3 changes: 1 addition & 2 deletions examples/examples_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
astra-toolbox
cached_property
colour_demosaicing
xdesign
xdesign>=0.5.5
ray[tune]
hyperopt
20 changes: 13 additions & 7 deletions scico/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from jax.experimental import host_callback as hcb

import bm3d as tunibm3d
try:
import bm3d as tunibm3d
except ImportError:
have_bm3d = False
else:
have_bm3d = True

import scico.numpy as snp
from scico._flax import DnCNNNet, load_weights
Expand All @@ -38,19 +43,21 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False):
Returns:
Denoised output.
"""
if not have_bm3d:
raise RuntimeError("Package bm3d is required for use of this function.")

if is_rgb is True:
bm3d_eval = tunibm3d.bm3d_rgb
else:
bm3d_eval = tunibm3d.bm3d

if np.iscomplexobj(x):
raise TypeError(f"BM3D requries real-valued inputs, got {x.dtype}")
raise TypeError(f"BM3D requires real-valued inputs, got {x.dtype}")

# Support arrays with more than three axes when the additional axes are singletons.
x_in_shape = x.shape

if x.ndim < 2:
if isinstance(x.ndim, tuple) or x.ndim < 2:
tbalke marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"BM3D requires two dimensional (M, N) or three dimensional (M, N, C)"
f" inputs; got ndim = {x.ndim}"
Expand All @@ -62,7 +69,7 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False):
# updated; this presumes 'np' profile (bs=8)
if np.min(x.shape[:2]) < 8:
raise ValueError(
f"Two leading dimensions of input cannot be smaller than block size "
"Two leading dimensions of input cannot be smaller than block size "
f"(8); got image size = {x.shape}"
)

Expand Down Expand Up @@ -94,7 +101,6 @@ class DnCNN(FlaxMap):

def __init__(self, variant: str = "6M"):
"""

Note that all DnCNN models are trained for single-channel image
input. Multi-channel input is supported via independent denoising
of each channel.
Expand All @@ -109,7 +115,7 @@ def __init__(self, variant: str = "6M"):
with respect to data in the range [0, 1].
"""
if variant not in ["6L", "6M", "6H", "17L", "17M", "17H"]:
raise RuntimeError(f"Invalid value of parameter variant: {variant}")
raise ValueError(f"Invalid value of parameter variant: {variant}")
if variant[0] == "6":
nlayer = 6
else:
Expand All @@ -130,7 +136,7 @@ def __call__(self, x: JaxArray) -> JaxArray:
if np.iscomplexobj(x):
raise TypeError(f"DnCNN requries real-valued inputs, got {x.dtype}")

if x.ndim < 2:
if isinstance(x.ndim, tuple) or x.ndim < 2:
raise ValueError(
"DnCNN requires two dimensional (M, N) or three dimensional (M, N, C)"
f" inputs; got ndim = {x.ndim}"
Expand Down
4 changes: 2 additions & 2 deletions scico/linop/abel.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def inverse(self, y: JaxArray) -> JaxArray:
"""Perform inverse Abel transform.

Args:
y: Input image (assumed to be a result of an Abel transform)
y: Input image (assumed to be a result of an Abel transform).

Returns:
Output of inverse Abel transform
Output of inverse Abel transform.
"""
return _pyabel_transform(y, direction="inverse", proj_mat_quad=self.proj_mat_quad).astype(
self.input_dtype
Expand Down
2 changes: 0 additions & 2 deletions scico/ray/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.trial import Trial

__author__ = """Brendt Wohlberg <[email protected]>"""


class _CustomReporter(TuneReporterBase):
"""Custom status reporter for :mod:`ray.tune`."""
Expand Down
28 changes: 17 additions & 11 deletions scico/test/test_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import pytest

from scico.denoiser import DnCNN, bm3d
from scico.denoiser import DnCNN, bm3d, have_bm3d
from scico.random import randn


@pytest.mark.skipif(not have_bm3d, reason="bm3d package not installed")
class TestBM3D:
def setup(self):
key = None
Expand Down Expand Up @@ -36,16 +37,16 @@ def test_bad_inputs(self):
x, key = randn((32,), key=None, dtype=np.float32)
with pytest.raises(ValueError):
bm3d(x, 1.0)

x, key = randn((12, 12, 4, 3), key=None, dtype=np.float32)
x, key = randn((12, 12, 4, 3), key=key, dtype=np.float32)
with pytest.raises(ValueError):
bm3d(x, 1.0)

x_b, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32)
x, key = randn(((2, 3), (3, 4, 5)), key=key, dtype=np.float32)
with pytest.raises(ValueError):
bm3d(x, 1.0)

z, key = randn((32, 32), key=None, dtype=np.complex64)
x, key = randn((5, 9), key=key, dtype=np.float32)
with pytest.raises(ValueError):
bm3d(x, 1.0)
z, key = randn((32, 32), key=key, dtype=np.complex64)
with pytest.raises(TypeError):
bm3d(z, 1.0)

Expand All @@ -71,19 +72,24 @@ def test_multi_channel(self):
assert no_jit.dtype == np.float32
assert jitted.dtype == np.float32

def test_init(self):
dncnn = DnCNN(variant="6L")
x = dncnn(self.x_sngchn)
dncnn = DnCNN(variant="17H")
x = dncnn(self.x_mltchn)
with pytest.raises(ValueError):
dncnn = DnCNN(variant="3A")

def test_bad_inputs(self):
x, key = randn((32,), key=None, dtype=np.float32)
with pytest.raises(ValueError):
self.dncnn(x)

x, key = randn((12, 12, 4, 3), key=None, dtype=np.float32)
with pytest.raises(ValueError):
self.dncnn(x)

x_b, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32)
x, key = randn(((2, 3), (3, 4, 5)), key=None, dtype=np.float32)
with pytest.raises(ValueError):
self.dncnn(x)

z, key = randn((32, 32), key=None, dtype=np.complex64)
with pytest.raises(TypeError):
self.dncnn(z)
2 changes: 2 additions & 0 deletions scico/test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import scico.numpy as snp
from scico import denoiser, functional
from scico.blockarray import BlockArray
from scico.denoiser import have_bm3d
from scico.random import randn

NO_BLOCK_ARRAY = [functional.L21Norm, functional.NuclearNorm]
Expand Down Expand Up @@ -286,6 +287,7 @@ def foo(c):
np.testing.assert_allclose(non_pmap, pmapped)


@pytest.mark.skipif(not have_bm3d, reason="bm3d package not installed")
class TestBM3D:
def setup(self):
key = None
Expand Down
32 changes: 28 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_init_variable_value(var):
"""

# Set install_requires from requirements.txt file
with open(os.path.join("requirements.txt")) as f:
with open("requirements.txt") as f:
lines = f.readlines()
install_requires = [line.strip() for line in lines]

Expand All @@ -59,8 +59,21 @@ def get_init_variable_value(var):
f"requirements.txt ({req_jax_ver}) do not match"
)

tests_require = ["pytest", "pytest-runner"]
python_requires = ">=3.8"
tests_require = ["pytest", "pytest-runner"]

extra_require_files = [
"dev_requirements.txt",
os.path.join("docs", "docs_requirements.txt"),
os.path.join("examples", "examples_requirements.txt"),
os.path.join("examples", "notebooks_requirements.txt"),
]
extras_require = {"tests": tests_require}
for require_file in extra_require_files:
extras_label = os.path.basename(require_file).partition("_")[0]
with open(require_file) as f:
lines = f.readlines()
extras_require[extras_label] = [line.strip() for line in lines if line[0:2] != "-r"]


setup(
Expand All @@ -69,7 +82,18 @@ def get_init_variable_value(var):
description="Scientific Computational Imaging COde: A Python "
"package for scientific imaging problems",
long_description=longdesc,
keywords=["Computational Imaging", "Inverse Problems", "Optimization", "ADMM", "PGM"],
keywords=[
"Computational Imaging",
"Scientific Imaging",
"Inverse Problems",
"Plug-and-Play Priors",
"Total Variation",
"Optimization",
"ADMM",
"Linearized ADMM",
"PDHG",
"PGM",
],
platforms="Any",
license="BSD",
url="https://github.com/lanl/scico",
Expand All @@ -81,7 +105,7 @@ def get_init_variable_value(var):
python_requires=python_requires,
tests_require=tests_require,
install_requires=install_requires,
extras_require={"tests": tests_require},
extras_require=extras_require,
classifiers=[
"License :: OSI Approved :: BSD License",
"Development Status :: 3 - Alpha",
Expand Down