Skip to content

Commit

Permalink
Merge branch 'main' into brendt/2d_ct_rename
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Sep 28, 2023
2 parents e5d77ec + 085bbcc commit 1fd8e23
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
6 changes: 3 additions & 3 deletions scico/linop/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ def __init__(
im_shape: Shape,
angles: ArrayLike,
det_length: Optional[int] = None,
do_dithering: bool = True,
dither: bool = True,
):
r"""
Args:
im_shape: Shape of input array.
angles: (num_angles,) array of angles in radians.
det_length: Length of detector, in ``None``, defaults to the
length of diagonal of `im_shape`.
do_dither: If ``True`` randomly shift pixel locations to
dither: If ``True`` randomly shift pixel locations to
reduce projection artifacts caused by aliasing.
"""
self.im_shape = im_shape
Expand Down Expand Up @@ -92,7 +92,7 @@ def compute_inds(angle: float) -> ArrayLike:
)

# dither
if do_dithering:
if dither:
key = jax.random.PRNGKey(0)
x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5)

Expand Down
21 changes: 8 additions & 13 deletions scico/test/test_ray_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@

try:
import ray
from scico.ray import report, tune
from scico.ray import train, tune

ray.init(num_cpus=1)
except ImportError as e:
pytest.skip("ray.tune not installed", allow_module_level=True)


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_random_run():
def eval_params(config, reporter):
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
reporter(cost=cost)
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
resources = {"gpu": 0, "cpu": 1}
Expand All @@ -40,12 +39,11 @@ def eval_params(config, reporter):
assert np.abs(best_config["y"] - 0.5) < 0.25


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_random_tune():
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
report({"cost": cost})
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
resources = {"gpu": 0, "cpu": 1}
Expand All @@ -66,12 +64,11 @@ def eval_params(config):
assert np.abs(best_config["y"] - 0.5) < 0.25


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_hyperopt_run():
def eval_params(config, reporter):
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
reporter(cost=cost)
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
resources = {"gpu": 0, "cpu": 1}
Expand All @@ -90,12 +87,11 @@ def eval_params(config, reporter):
assert np.abs(best_config["y"] - 0.5) < 0.25


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_hyperopt_tune():
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
report({"cost": cost})
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
resources = {"gpu": 0, "cpu": 1}
Expand All @@ -115,12 +111,11 @@ def eval_params(config):
assert np.abs(best_config["y"] - 0.5) < 0.25


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_hyperopt_tune_alt_init():
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
report({"cost": cost})
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
tuner = tune.Tuner(
Expand Down

0 comments on commit 1fd8e23

Please sign in to comment.