From 18e39189d0d15aee0f456eee197f822586989c63 Mon Sep 17 00:00:00 2001 From: Michael McCann Date: Fri, 28 Jul 2023 12:52:59 -0600 Subject: [PATCH 01/13] Add 2d projector and code to time it --- examples/scripts/ct_projector_timing.py | 85 +++++++++++++++ scico/linop/__init__.py | 3 + scico/linop/_xray.py | 135 ++++++++++++++++++++++++ scico/util.py | 1 + 4 files changed, 224 insertions(+) create mode 100644 examples/scripts/ct_projector_timing.py create mode 100644 scico/linop/_xray.py diff --git a/examples/scripts/ct_projector_timing.py b/examples/scripts/ct_projector_timing.py new file mode 100644 index 000000000..f5569058e --- /dev/null +++ b/examples/scripts/ct_projector_timing.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + + +import time + +import jax +import jax.numpy as jnp +from xdesign import Foam, discrete_phantom + +from scico.linop import XRayProject, ParallelFixedAxis2dProjector +from scico.linop.radon_astra import TomographicProjector +from scico.util import Timer + + +N = 512 +num_angles = 512 + +det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) + +x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) +x_gt = jax.device_put(x_gt) + +angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) + +method_names = ["scico", "astra"] +timer = Timer( + [n + "_init" for n in method_names] + + [n + "_first_proj" for n in method_names] + + [n + "_avg_proj" for n in method_names] +) + +projectors = {} +timer.start("scico_init") +projectors["scico"] = XRayProject(ParallelFixedAxis2dProjector((N, N), angles)) +timer.stop("scico_init") + +timer.start("astra_init") +projectors["astra"] = TomographicProjector( + (N, N), detector_spacing=1.0, det_count=det_count, angles=angles +) +timer.stop("astra_init") + +ys = {} +for name, H in projectors.items(): + timer_label = f"{name}_first_proj" + timer.start(timer_label) + ys[name] = H @ x_gt + timer.stop(timer_label) + + +num_repeats = 3 +for name, H in projectors.items(): + timer_label = f"{name}_avg_proj" + timer.start(timer_label) + for _ in range(num_repeats): + ys[name] = H @ x_gt + timer.stop(timer_label) + timer.td[timer_label] /= num_repeats + + +print(timer) + +""" +with way 2: +Label Accum. Current +------------------------------------------- +astra_avg_proj 7.30e-01 s Stopped +astra_first_proj 7.41e-01 s Stopped +astra_init 4.63e-03 s Stopped +scico_avg_proj 9.96e-01 s Stopped +scico_first_proj 9.98e-01 s Stopped +scico_init 8.02e+00 s Stopped +""" + +fig, ax = plt.subplots() +ax.imshow(ys["scico"]) +fig.show() + +fig, ax = plt.subplots() +ax.imshow(ys["astra"]) +fig.show() diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index 8fca29a1f..d01708800 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -18,6 +18,7 @@ from ._linop import ComposedLinearOperator, LinearOperator from ._matrix import MatrixOperator from ._stack import DiagonalStack, VerticalStack +from ._xray import XRayProject, ParallelFixedAxis2dProjector from ._util import jacobian, operator_norm, power_iteration, valid_adjoint __all__ = [ @@ -38,6 +39,8 @@ "Sum", "Transpose", "LinearOperator", + "XRayProject", + "ParallelFixedAxis2dProjector", "ComposedLinearOperator", "linop_from_function", "operator_norm", diff --git a/scico/linop/_xray.py b/scico/linop/_xray.py new file mode 100644 index 000000000..2a236e867 --- /dev/null +++ b/scico/linop/_xray.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2020-2023 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + + +""" +X-ray projector classes. +""" +from functools import partial + +import numpy as np +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike + +from ._linop import LinearOperator + + +class XRayProject(LinearOperator): + """options to select type of projection""" + + def __init__(self, projector): + self._eval = projector.project + + super().__init__( + input_shape=projector.im_shape, + output_shape=(len(projector.angles), *projector.det_shape), + ) + + +class ParallelFixedAxis2dProjector: + """Parallel ray, single axis, 2D X-ray projector""" + + def __init__(self, im_shape, angles, det_length=None, dither=True): + self.im_shape = im_shape + self.angles = angles + + im_shape = np.array(im_shape) + + x0 = -(im_shape - 1) / 2 + + if det_length is None: + det_length = int(np.ceil(np.linalg.norm(im_shape))) + self.det_shape = (det_length,) + + y0 = -det_length / 2 + + @jax.vmap + def compute_inds(angle: float) -> ArrayLike: + # fast, but does not allow easy dithering + # dydx = jnp.stack((jnp.cos(angle), jnp.sin(angle))) + # Px0 = jnp.dot(x0, dydx) + # Px = ( + # Px0 + # + dydx[0] * jnp.arange(im_shape[0])[:, jnp.newaxis] + # + dydx[1] * jnp.arange(im_shape[1])[jnp.newaxis, :] + # ) + + x = jnp.stack( + jnp.meshgrid( + *( + jnp.arange(shape_i) * step_i + start_i + for start_i, step_i, shape_i in zip(x0, [1, 1], im_shape) + ), + indexing="ij", + ), + axis=-1, + ) + + # dither + if dither: + key = jax.random.PRNGKey(0) + x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5) + + # project + Px = x[..., 0] * jnp.cos(angle) + x[..., 1] * jnp.sin(angle) + + # quantize + inds = jnp.floor((Px - y0)).astype(int) + + # map negative inds to y_size, which is out of bounds and will be ignored + # otherwise they index from the end like x[-1] + inds = jnp.where(inds < 0, det_length, inds) + + return inds + + inds = compute_inds(angles) + + @partial(jax.vmap, in_axes=(None, 0)) + def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike: + return jnp.zeros(det_length).at[inds].add(im) + + @jax.jit + def project(im: ArrayLike) -> ArrayLike: + return project_inds(im, inds) + + self.project = project + + +# num_angles = 127 + +# x = jnp.ones((128, 129)) + + +# H = ParallelFixedAxis2dProjector(x.shape, angles) +# y1 = H.project(x) + +# import matplotlib.pyplot as plt + +# fig, ax = plt.subplots() +# ax.imshow(y) +# fig.show() + +# f = lambda x: H.project(x)[65, 90] +# grad_f = jax.grad(f) + +# fig, ax = plt.subplots() +# ax.imshow(grad_f(x)) +# fig.show() + + +# ## back project + + +# bad_angle = jnp.array([jnp.pi / 4]) +# H = ParallelFixedAxis2dProjector(x.shape, bad_angle) +# y = H.project(x) + + +# fig, ax = plt.subplots() +# ax.plot(y[0]) +# fig.show() diff --git a/scico/util.py b/scico/util.py index d57c8efc6..2c3cf44c1 100644 --- a/scico/util.py +++ b/scico/util.py @@ -395,6 +395,7 @@ def __str__(self) -> str: s += "-" * (lfldln + 25) + "\n" # Construct table of timer details for lbl in sorted(self.t0): + print(lbl) td = self.td[lbl] if self.t0[lbl] is None: ts = " Stopped" From 099d6e4742b596ff8c47674b91d083d36d21dbc9 Mon Sep 17 00:00:00 2001 From: Michael McCann Date: Fri, 28 Jul 2023 14:15:52 -0600 Subject: [PATCH 02/13] Clean up --- ...r_timing.py => ct_projector_comparison.py} | 68 ++++++++++++------- scico/linop/_xray.py | 38 +---------- 2 files changed, 47 insertions(+), 59 deletions(-) rename examples/scripts/{ct_projector_timing.py => ct_projector_comparison.py} (59%) diff --git a/examples/scripts/ct_projector_timing.py b/examples/scripts/ct_projector_comparison.py similarity index 59% rename from examples/scripts/ct_projector_timing.py rename to examples/scripts/ct_projector_comparison.py index f5569058e..e24c9f0ec 100644 --- a/examples/scripts/ct_projector_timing.py +++ b/examples/scripts/ct_projector_comparison.py @@ -5,33 +5,45 @@ # with the package. -import time +r""" +X-ray Projector Comparison +========================== + +This example compares SCICO's native X-ray projection algorithm +to that of the ASTRA Toolbox. +""" import jax import jax.numpy as jnp + from xdesign import Foam, discrete_phantom -from scico.linop import XRayProject, ParallelFixedAxis2dProjector +from scico import plot +from scico.linop import ParallelFixedAxis2dProjector, XRayProject from scico.linop.radon_astra import TomographicProjector from scico.util import Timer +""" +Create a ground truth image. +""" N = 512 -num_angles = 512 + det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) x_gt = jax.device_put(x_gt) +""" +Time projector instantiation. +""" + +num_angles = 500 angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) -method_names = ["scico", "astra"] -timer = Timer( - [n + "_init" for n in method_names] - + [n + "_first_proj" for n in method_names] - + [n + "_avg_proj" for n in method_names] -) + +timer = Timer() projectors = {} timer.start("scico_init") @@ -44,42 +56,52 @@ ) timer.stop("astra_init") +""" +Time first projector application, which might include JIT overhead. +""" + ys = {} for name, H in projectors.items(): timer_label = f"{name}_first_proj" timer.start(timer_label) ys[name] = H @ x_gt + jax.block_until_ready(ys[name]) timer.stop(timer_label) +""" +Compute average time for a projector application. +""" + num_repeats = 3 for name, H in projectors.items(): timer_label = f"{name}_avg_proj" timer.start(timer_label) for _ in range(num_repeats): ys[name] = H @ x_gt + jax.block_until_ready(ys[name]) timer.stop(timer_label) timer.td[timer_label] /= num_repeats +""" +Display timing results. + +On our server, using the GPU: + + +Using the CPU: + +""" print(timer) """ -with way 2: -Label Accum. Current -------------------------------------------- -astra_avg_proj 7.30e-01 s Stopped -astra_first_proj 7.41e-01 s Stopped -astra_init 4.63e-03 s Stopped -scico_avg_proj 9.96e-01 s Stopped -scico_first_proj 9.98e-01 s Stopped -scico_init 8.02e+00 s Stopped +Show projections. """ -fig, ax = plt.subplots() -ax.imshow(ys["scico"]) +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 5)) +plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0]) +plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1]) fig.show() -fig, ax = plt.subplots() -ax.imshow(ys["astra"]) -fig.show() +input("\nWaiting for input to close figures and exit") diff --git a/scico/linop/_xray.py b/scico/linop/_xray.py index 2a236e867..6f53e9ff0 100644 --- a/scico/linop/_xray.py +++ b/scico/linop/_xray.py @@ -12,6 +12,7 @@ from functools import partial import numpy as np + import jax import jax.numpy as jnp from jax.typing import ArrayLike @@ -87,7 +88,7 @@ def compute_inds(angle: float) -> ArrayLike: return inds - inds = compute_inds(angles) + inds = compute_inds(angles) # (len(angles), *im_shape) @partial(jax.vmap, in_axes=(None, 0)) def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike: @@ -98,38 +99,3 @@ def project(im: ArrayLike) -> ArrayLike: return project_inds(im, inds) self.project = project - - -# num_angles = 127 - -# x = jnp.ones((128, 129)) - - -# H = ParallelFixedAxis2dProjector(x.shape, angles) -# y1 = H.project(x) - -# import matplotlib.pyplot as plt - -# fig, ax = plt.subplots() -# ax.imshow(y) -# fig.show() - -# f = lambda x: H.project(x)[65, 90] -# grad_f = jax.grad(f) - -# fig, ax = plt.subplots() -# ax.imshow(grad_f(x)) -# fig.show() - - -# ## back project - - -# bad_angle = jnp.array([jnp.pi / 4]) -# H = ParallelFixedAxis2dProjector(x.shape, bad_angle) -# y = H.project(x) - - -# fig, ax = plt.subplots() -# ax.plot(y[0]) -# fig.show() From 542c29907fa4f1216f68af66c324cb5ae1686a95 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 28 Jul 2023 15:12:28 -0600 Subject: [PATCH 03/13] Add back projection --- examples/scripts/ct_projector_comparison.py | 89 ++++++++++++++++++++- scico/linop/_xray.py | 51 ++++++++---- scico/util.py | 1 - 3 files changed, 122 insertions(+), 19 deletions(-) diff --git a/examples/scripts/ct_projector_comparison.py b/examples/scripts/ct_projector_comparison.py index e24c9f0ec..c8d150881 100644 --- a/examples/scripts/ct_projector_comparison.py +++ b/examples/scripts/ct_projector_comparison.py @@ -13,6 +13,8 @@ to that of the ASTRA Toolbox. """ +import numpy as np + import jax import jax.numpy as jnp @@ -86,10 +88,31 @@ """ Display timing results. -On our server, using the GPU: +On our server, the SCICO projection is more than twice +as fast as ASTRA when run on the GPU, and about about +10% slower on the CPU. +On our server, using the GPU: +Label Accum. Current +------------------------------------------- +astra_avg_proj 4.62e-02 s Stopped +astra_first_proj 6.92e-02 s Stopped +astra_init 1.36e-03 s Stopped +scico_avg_proj 1.61e-02 s Stopped +scico_first_proj 2.95e-02 s Stopped +scico_init 1.37e+01 s Stopped Using the CPU: +Label Accum. Current +------------------------------------------- +astra_avg_proj 9.11e-01 s Stopped +astra_first_proj 9.16e-01 s Stopped +astra_init 1.06e-03 s Stopped +scico_avg_proj 1.03e+00 s Stopped +scico_first_proj 1.04e+00 s Stopped +scico_init 1.00e+01 s Stopped + + """ @@ -99,9 +122,71 @@ Show projections. """ -fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 5)) +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0]) plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1]) fig.show() + +""" +Time first back projection, which might include JIT overhead. +""" +timer = Timer() + +y = np.zeros(H.output_shape, dtype=np.float32) +y[num_angles // 3, det_count // 2] = 1.0 +y = jax.device_put(y) + +HTys = {} +for name, H in projectors.items(): + timer_label = f"{name}_first_BP" + timer.start(timer_label) + HTys[name] = H.T @ y + jax.block_until_ready(ys[name]) + timer.stop(timer_label) + + +""" +Compute average time for back projection. +""" +num_repeats = 3 +for name, H in projectors.items(): + timer_label = f"{name}_avg_BP" + timer.start(timer_label) + for _ in range(num_repeats): + HTys[name] = H.T @ y + jax.block_until_ready(ys[name]) + timer.stop(timer_label) + timer.td[timer_label] /= num_repeats + +""" +Display back projection timing results. + + + +On our server, using the GPU: + + +Using the CPU: + + + +""" + +print(timer) + +""" +Show back projections of a single detector element, +i.e., a line. +""" + +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) +plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0]) +plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1]) +for ax_i in ax: + ax_i.set_xlim(2 * N / 5, N - 2 * N / 5) + ax_i.set_ylim(2 * N / 5, N - 2 * N / 5) +fig.show() + + input("\nWaiting for input to close figures and exit") diff --git a/scico/linop/_xray.py b/scico/linop/_xray.py index 6f53e9ff0..11e0fdf05 100644 --- a/scico/linop/_xray.py +++ b/scico/linop/_xray.py @@ -17,13 +17,25 @@ import jax.numpy as jnp from jax.typing import ArrayLike +from scico.typing import Shape + from ._linop import LinearOperator class XRayProject(LinearOperator): - """options to select type of projection""" - - def __init__(self, projector): + """X-ray projection operator. + + Wraps an X-ray projector object in a SCICO + :class:`LinearOperator`. + """ + + def __init__(self, projector: object): + r""" + Args: + projector: instance of an X-ray projector object to wrap, + currently the only option is + :class:`ParallelFixedAxis2dProjector` + """ self._eval = projector.project super().__init__( @@ -33,9 +45,20 @@ def __init__(self, projector): class ParallelFixedAxis2dProjector: - """Parallel ray, single axis, 2D X-ray projector""" - - def __init__(self, im_shape, angles, det_length=None, dither=True): + """Parallel ray, single axis, 2D X-ray projector.""" + + def __init__( + self, im_shape: Shape, angles: ArrayLike, det_length: int = None, do_dithering: bool = True + ): + r""" + Args: + im_shape: Shape of input array. + angles: (num_angles,) array of angles in radians. + det_length: Length of detector, in ``None``, defaults to the + length of diagonal of `im_shape`. + do_dither: If ``True`` randomly shift pixel locations to + reduce projection artifacts caused by aliasing. + """ self.im_shape = im_shape self.angles = angles @@ -51,15 +74,9 @@ def __init__(self, im_shape, angles, det_length=None, dither=True): @jax.vmap def compute_inds(angle: float) -> ArrayLike: - # fast, but does not allow easy dithering - # dydx = jnp.stack((jnp.cos(angle), jnp.sin(angle))) - # Px0 = jnp.dot(x0, dydx) - # Px = ( - # Px0 - # + dydx[0] * jnp.arange(im_shape[0])[:, jnp.newaxis] - # + dydx[1] * jnp.arange(im_shape[1])[jnp.newaxis, :] - # ) - + """Project pixel positions on to a detector at the given + angle, determine which detector element they contribute to. + """ x = jnp.stack( jnp.meshgrid( *( @@ -72,7 +89,7 @@ def compute_inds(angle: float) -> ArrayLike: ) # dither - if dither: + if do_dithering: key = jax.random.PRNGKey(0) x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5) @@ -92,10 +109,12 @@ def compute_inds(angle: float) -> ArrayLike: @partial(jax.vmap, in_axes=(None, 0)) def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike: + """Compute the projection at a single angle.""" return jnp.zeros(det_length).at[inds].add(im) @jax.jit def project(im: ArrayLike) -> ArrayLike: + """Compute the projection for all angles.""" return project_inds(im, inds) self.project = project diff --git a/scico/util.py b/scico/util.py index 2c3cf44c1..d57c8efc6 100644 --- a/scico/util.py +++ b/scico/util.py @@ -395,7 +395,6 @@ def __str__(self) -> str: s += "-" * (lfldln + 25) + "\n" # Construct table of timer details for lbl in sorted(self.t0): - print(lbl) td = self.td[lbl] if self.t0[lbl] is None: ts = " Stopped" From 3f623b6a6d7bc4ea21478e03457ae7f448470c84 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 28 Jul 2023 15:12:46 -0600 Subject: [PATCH 04/13] Add test --- scico/test/linop/test_xray.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 scico/test/linop/test_xray.py diff --git a/scico/test/linop/test_xray.py b/scico/test/linop/test_xray.py new file mode 100644 index 000000000..7733158dd --- /dev/null +++ b/scico/test/linop/test_xray.py @@ -0,0 +1,26 @@ +import jax.numpy as jnp + +from scico.linop import ParallelFixedAxis2dProjector, XRayProject + + +def test_apply(): + im_shape = (12, 13) + num_angles = 10 + x = jnp.ones(im_shape) + + angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) + + # general projection + H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles)) + y = H @ x + assert y.shape[0] == (num_angles) + + # fixed det_length + det_length = 14 + H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, det_length=det_length)) + y = H @ x + assert y.shape[1] == det_length + + # dither off + H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, do_dithering=False)) + y = H @ x From 26f721f6402bf76c5cdd47f216c2a164022cd45f Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 28 Jul 2023 15:18:26 -0600 Subject: [PATCH 05/13] Add timing results to example --- examples/scripts/ct_projector_comparison.py | 29 ++++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/scripts/ct_projector_comparison.py b/examples/scripts/ct_projector_comparison.py index c8d150881..22d18bb0e 100644 --- a/examples/scripts/ct_projector_comparison.py +++ b/examples/scripts/ct_projector_comparison.py @@ -89,8 +89,8 @@ Display timing results. On our server, the SCICO projection is more than twice -as fast as ASTRA when run on the GPU, and about about -10% slower on the CPU. +as fast as ASTRA when both are run on the GPU, and about +10% slower when both are run the CPU. On our server, using the GPU: Label Accum. Current @@ -111,9 +111,6 @@ scico_avg_proj 1.03e+00 s Stopped scico_first_proj 1.04e+00 s Stopped scico_init 1.00e+01 s Stopped - - - """ print(timer) @@ -162,15 +159,27 @@ """ Display back projection timing results. - +On our server, the SCICO back projection is slow +the first time it is run, probably due to JIT overhead. +After the first run, it is an order of magnitude +faster than ASTRA when both are run on the GPU, +and about three times faster when both are run on the CPU. On our server, using the GPU: - +Label Accum. Current +----------------------------------------- +astra_avg_BP 3.71e-02 s Stopped +astra_first_BP 4.20e-02 s Stopped +scico_avg_BP 1.05e-03 s Stopped +scico_first_BP 7.63e+00 s Stopped Using the CPU: - - - +Label Accum. Current +----------------------------------------- +astra_avg_BP 9.34e-01 s Stopped +astra_first_BP 9.39e-01 s Stopped +scico_avg_BP 2.62e-01 s Stopped +scico_first_BP 1.00e+01 s Stopped """ print(timer) From 879d4b955e5dba4384f53e94a61e17ff4b1c5bdc Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 28 Jul 2023 15:21:08 -0600 Subject: [PATCH 06/13] Start to add new example --- data | 2 +- docs/source/examples.rst | 1 + examples/scripts/README.rst | 2 ++ examples/scripts/index.rst | 1 + 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/data b/data index 80c35007d..da0039945 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit 80c35007dc595fc6553b1420c3d282c6e1fb04c1 +Subproject commit da0039945cc77390c7d9b77330addeeb3592e76d diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 368d8fa98..c6cc70e9a 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -35,6 +35,7 @@ Computed Tomography examples/ct_astra_modl_train_foam2 examples/ct_astra_odp_train_foam2 examples/ct_astra_unet_train_foam2 + examples/ct_projector_comparison Deconvolution diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index f82dd3aa7..3910e9671 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -35,6 +35,8 @@ Computed Tomography CT Training and Reconstructions with ODP `ct_astra_unet_train_foam2.py `_ CT Training and Reconstructions with UNet + `ct_projector_comparison.py `_ + X-ray Projector Comparison Deconvolution diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index 683bf0893..f03f8fa28 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -22,6 +22,7 @@ Computed Tomography - ct_astra_modl_train_foam2.py - ct_astra_odp_train_foam2.py - ct_astra_unet_train_foam2.py + - ct_projector_comparison.py Deconvolution From 84ba712da51acbd72e404e9f05e18e36789ec198 Mon Sep 17 00:00:00 2001 From: Michael McCann Date: Fri, 28 Jul 2023 15:27:38 -0600 Subject: [PATCH 07/13] Update data --- data | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data b/data index da0039945..ebceeb797 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit da0039945cc77390c7d9b77330addeeb3592e76d +Subproject commit ebceeb7973cfd88cf1a22b3c3dc78aa474529ecf From 463f40a41afa31ef0ed7a9193a835aca07ae6f79 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 28 Jul 2023 15:35:15 -0600 Subject: [PATCH 08/13] Address mypy --- scico/linop/_xray.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/scico/linop/_xray.py b/scico/linop/_xray.py index 11e0fdf05..7704d4bdc 100644 --- a/scico/linop/_xray.py +++ b/scico/linop/_xray.py @@ -10,6 +10,7 @@ X-ray projector classes. """ from functools import partial +from typing import Optional import numpy as np @@ -29,7 +30,7 @@ class XRayProject(LinearOperator): :class:`LinearOperator`. """ - def __init__(self, projector: object): + def __init__(self, projector): r""" Args: projector: instance of an X-ray projector object to wrap, @@ -48,7 +49,11 @@ class ParallelFixedAxis2dProjector: """Parallel ray, single axis, 2D X-ray projector.""" def __init__( - self, im_shape: Shape, angles: ArrayLike, det_length: int = None, do_dithering: bool = True + self, + im_shape: Shape, + angles: ArrayLike, + det_length: Optional[int] = None, + do_dithering: bool = True, ): r""" Args: From 7a2cc3a7906c2a1f6fd0ccad601018923198f461 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 28 Jul 2023 15:36:04 -0600 Subject: [PATCH 09/13] Address isort --- scico/linop/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index d01708800..0c14de950 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -18,8 +18,8 @@ from ._linop import ComposedLinearOperator, LinearOperator from ._matrix import MatrixOperator from ._stack import DiagonalStack, VerticalStack -from ._xray import XRayProject, ParallelFixedAxis2dProjector from ._util import jacobian, operator_norm, power_iteration, valid_adjoint +from ._xray import ParallelFixedAxis2dProjector, XRayProject __all__ = [ "CircularConvolve", From c63b6478280eb12b8a7355e885076e5f0abdbae1 Mon Sep 17 00:00:00 2001 From: Michael McCann Date: Fri, 28 Jul 2023 15:46:27 -0600 Subject: [PATCH 10/13] Try to fix tables in notebook --- data | 2 +- examples/scripts/ct_projector_comparison.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/data b/data index ebceeb797..464c2a74d 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit ebceeb7973cfd88cf1a22b3c3dc78aa474529ecf +Subproject commit 464c2a74d7d2a03844eb3d889135e4ddca0e4645 diff --git a/examples/scripts/ct_projector_comparison.py b/examples/scripts/ct_projector_comparison.py index 22d18bb0e..ab8c43cfa 100644 --- a/examples/scripts/ct_projector_comparison.py +++ b/examples/scripts/ct_projector_comparison.py @@ -93,6 +93,7 @@ 10% slower when both are run the CPU. On our server, using the GPU: +``` Label Accum. Current ------------------------------------------- astra_avg_proj 4.62e-02 s Stopped @@ -101,8 +102,10 @@ scico_avg_proj 1.61e-02 s Stopped scico_first_proj 2.95e-02 s Stopped scico_init 1.37e+01 s Stopped +``` Using the CPU: +``` Label Accum. Current ------------------------------------------- astra_avg_proj 9.11e-01 s Stopped @@ -111,6 +114,7 @@ scico_avg_proj 1.03e+00 s Stopped scico_first_proj 1.04e+00 s Stopped scico_init 1.00e+01 s Stopped +``` """ print(timer) @@ -166,20 +170,24 @@ and about three times faster when both are run on the CPU. On our server, using the GPU: +``` Label Accum. Current ----------------------------------------- astra_avg_BP 3.71e-02 s Stopped astra_first_BP 4.20e-02 s Stopped scico_avg_BP 1.05e-03 s Stopped scico_first_BP 7.63e+00 s Stopped +``` Using the CPU: +``` Label Accum. Current ----------------------------------------- astra_avg_BP 9.34e-01 s Stopped astra_first_BP 9.39e-01 s Stopped scico_avg_BP 2.62e-01 s Stopped scico_first_BP 1.00e+01 s Stopped +``` """ print(timer) From 8b35d851e365557695c1a77b2ea6ce9ad65e0e99 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 1 Sep 2023 15:58:07 -0600 Subject: [PATCH 11/13] Update submodule --- data | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data b/data index 464c2a74d..0d9f1fef8 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit 464c2a74d7d2a03844eb3d889135e4ddca0e4645 +Subproject commit 0d9f1fef8df6eebb98d154e1e6d1ab8357914a88 From 83d129f58b082abef8165b59fac69864c2eb4709 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 27 Sep 2023 12:17:19 -0600 Subject: [PATCH 12/13] Rename parameter --- scico/linop/_xray.py | 6 +++--- scico/test/linop/test_xray.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scico/linop/_xray.py b/scico/linop/_xray.py index 7704d4bdc..40c649cf4 100644 --- a/scico/linop/_xray.py +++ b/scico/linop/_xray.py @@ -53,7 +53,7 @@ def __init__( im_shape: Shape, angles: ArrayLike, det_length: Optional[int] = None, - do_dithering: bool = True, + dither: bool = True, ): r""" Args: @@ -61,7 +61,7 @@ def __init__( angles: (num_angles,) array of angles in radians. det_length: Length of detector, in ``None``, defaults to the length of diagonal of `im_shape`. - do_dither: If ``True`` randomly shift pixel locations to + dither: If ``True`` randomly shift pixel locations to reduce projection artifacts caused by aliasing. """ self.im_shape = im_shape @@ -94,7 +94,7 @@ def compute_inds(angle: float) -> ArrayLike: ) # dither - if do_dithering: + if dither: key = jax.random.PRNGKey(0) x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5) diff --git a/scico/test/linop/test_xray.py b/scico/test/linop/test_xray.py index 7733158dd..bb827988f 100644 --- a/scico/test/linop/test_xray.py +++ b/scico/test/linop/test_xray.py @@ -22,5 +22,5 @@ def test_apply(): assert y.shape[1] == det_length # dither off - H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, do_dithering=False)) + H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, dither=False)) y = H @ x From 3757233c6627b378df64fa7417eb8c97fa020ef0 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 27 Sep 2023 12:40:32 -0600 Subject: [PATCH 13/13] Update ray syntax to get tests passing in CI --- scico/test/test_ray_tune.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/scico/test/test_ray_tune.py b/scico/test/test_ray_tune.py index 682592934..dde5b1d37 100644 --- a/scico/test/test_ray_tune.py +++ b/scico/test/test_ray_tune.py @@ -7,19 +7,18 @@ try: import ray - from scico.ray import report, tune + from scico.ray import train, tune ray.init(num_cpus=1) except ImportError as e: pytest.skip("ray.tune not installed", allow_module_level=True) -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_random_run(): - def eval_params(config, reporter): + def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - reporter(cost=cost) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -40,12 +39,11 @@ def eval_params(config, reporter): assert np.abs(best_config["y"] - 0.5) < 0.25 -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_random_tune(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - report({"cost": cost}) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -66,12 +64,11 @@ def eval_params(config): assert np.abs(best_config["y"] - 0.5) < 0.25 -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_hyperopt_run(): - def eval_params(config, reporter): + def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - reporter(cost=cost) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -90,12 +87,11 @@ def eval_params(config, reporter): assert np.abs(best_config["y"] - 0.5) < 0.25 -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_hyperopt_tune(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - report({"cost": cost}) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -115,12 +111,11 @@ def eval_params(config): assert np.abs(best_config["y"] - 0.5) < 0.25 -@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_hyperopt_tune_alt_init(): def eval_params(config): x, y = config["x"], config["y"] cost = x**2 + (y - 0.5) ** 2 - report({"cost": cost}) + train.report({"cost": cost}) config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} tuner = tune.Tuner(