-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a jax-based X-ray projector (#433)
* Add 2d projector and code to time it * Clean up * Add back projection * Add test * Add timing results to example * Start to add new example * Update data * Address mypy * Address isort * Try to fix tables in notebook * Update submodule * Rename parameter * Update ray syntax to get tests passing in CI --------- Co-authored-by: Michael McCann <[email protected]> Co-authored-by: Michael McCann <[email protected]> Co-authored-by: Brendt Wohlberg <[email protected]>
- Loading branch information
1 parent
6fe8536
commit 085bbcc
Showing
9 changed files
with
376 additions
and
14 deletions.
There are no files selected for viewing
Submodule data
updated
2 files
+492 −0 | notebooks/ct_projector_comparison.ipynb | |
+2 −1 | notebooks/index.ipynb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# This file is part of the SCICO package. Details of the copyright | ||
# and user license can be found in the 'LICENSE.txt' file distributed | ||
# with the package. | ||
|
||
|
||
r""" | ||
X-ray Projector Comparison | ||
========================== | ||
This example compares SCICO's native X-ray projection algorithm | ||
to that of the ASTRA Toolbox. | ||
""" | ||
|
||
import numpy as np | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
from xdesign import Foam, discrete_phantom | ||
|
||
from scico import plot | ||
from scico.linop import ParallelFixedAxis2dProjector, XRayProject | ||
from scico.linop.radon_astra import TomographicProjector | ||
from scico.util import Timer | ||
|
||
""" | ||
Create a ground truth image. | ||
""" | ||
|
||
N = 512 | ||
|
||
|
||
det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) | ||
|
||
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) | ||
x_gt = jax.device_put(x_gt) | ||
|
||
""" | ||
Time projector instantiation. | ||
""" | ||
|
||
num_angles = 500 | ||
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) | ||
|
||
|
||
timer = Timer() | ||
|
||
projectors = {} | ||
timer.start("scico_init") | ||
projectors["scico"] = XRayProject(ParallelFixedAxis2dProjector((N, N), angles)) | ||
timer.stop("scico_init") | ||
|
||
timer.start("astra_init") | ||
projectors["astra"] = TomographicProjector( | ||
(N, N), detector_spacing=1.0, det_count=det_count, angles=angles | ||
) | ||
timer.stop("astra_init") | ||
|
||
""" | ||
Time first projector application, which might include JIT overhead. | ||
""" | ||
|
||
ys = {} | ||
for name, H in projectors.items(): | ||
timer_label = f"{name}_first_proj" | ||
timer.start(timer_label) | ||
ys[name] = H @ x_gt | ||
jax.block_until_ready(ys[name]) | ||
timer.stop(timer_label) | ||
|
||
|
||
""" | ||
Compute average time for a projector application. | ||
""" | ||
|
||
num_repeats = 3 | ||
for name, H in projectors.items(): | ||
timer_label = f"{name}_avg_proj" | ||
timer.start(timer_label) | ||
for _ in range(num_repeats): | ||
ys[name] = H @ x_gt | ||
jax.block_until_ready(ys[name]) | ||
timer.stop(timer_label) | ||
timer.td[timer_label] /= num_repeats | ||
|
||
""" | ||
Display timing results. | ||
On our server, the SCICO projection is more than twice | ||
as fast as ASTRA when both are run on the GPU, and about | ||
10% slower when both are run the CPU. | ||
On our server, using the GPU: | ||
``` | ||
Label Accum. Current | ||
------------------------------------------- | ||
astra_avg_proj 4.62e-02 s Stopped | ||
astra_first_proj 6.92e-02 s Stopped | ||
astra_init 1.36e-03 s Stopped | ||
scico_avg_proj 1.61e-02 s Stopped | ||
scico_first_proj 2.95e-02 s Stopped | ||
scico_init 1.37e+01 s Stopped | ||
``` | ||
Using the CPU: | ||
``` | ||
Label Accum. Current | ||
------------------------------------------- | ||
astra_avg_proj 9.11e-01 s Stopped | ||
astra_first_proj 9.16e-01 s Stopped | ||
astra_init 1.06e-03 s Stopped | ||
scico_avg_proj 1.03e+00 s Stopped | ||
scico_first_proj 1.04e+00 s Stopped | ||
scico_init 1.00e+01 s Stopped | ||
``` | ||
""" | ||
|
||
print(timer) | ||
|
||
""" | ||
Show projections. | ||
""" | ||
|
||
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) | ||
plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0]) | ||
plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1]) | ||
fig.show() | ||
|
||
|
||
""" | ||
Time first back projection, which might include JIT overhead. | ||
""" | ||
timer = Timer() | ||
|
||
y = np.zeros(H.output_shape, dtype=np.float32) | ||
y[num_angles // 3, det_count // 2] = 1.0 | ||
y = jax.device_put(y) | ||
|
||
HTys = {} | ||
for name, H in projectors.items(): | ||
timer_label = f"{name}_first_BP" | ||
timer.start(timer_label) | ||
HTys[name] = H.T @ y | ||
jax.block_until_ready(ys[name]) | ||
timer.stop(timer_label) | ||
|
||
|
||
""" | ||
Compute average time for back projection. | ||
""" | ||
num_repeats = 3 | ||
for name, H in projectors.items(): | ||
timer_label = f"{name}_avg_BP" | ||
timer.start(timer_label) | ||
for _ in range(num_repeats): | ||
HTys[name] = H.T @ y | ||
jax.block_until_ready(ys[name]) | ||
timer.stop(timer_label) | ||
timer.td[timer_label] /= num_repeats | ||
|
||
""" | ||
Display back projection timing results. | ||
On our server, the SCICO back projection is slow | ||
the first time it is run, probably due to JIT overhead. | ||
After the first run, it is an order of magnitude | ||
faster than ASTRA when both are run on the GPU, | ||
and about three times faster when both are run on the CPU. | ||
On our server, using the GPU: | ||
``` | ||
Label Accum. Current | ||
----------------------------------------- | ||
astra_avg_BP 3.71e-02 s Stopped | ||
astra_first_BP 4.20e-02 s Stopped | ||
scico_avg_BP 1.05e-03 s Stopped | ||
scico_first_BP 7.63e+00 s Stopped | ||
``` | ||
Using the CPU: | ||
``` | ||
Label Accum. Current | ||
----------------------------------------- | ||
astra_avg_BP 9.34e-01 s Stopped | ||
astra_first_BP 9.39e-01 s Stopped | ||
scico_avg_BP 2.62e-01 s Stopped | ||
scico_first_BP 1.00e+01 s Stopped | ||
``` | ||
""" | ||
|
||
print(timer) | ||
|
||
""" | ||
Show back projections of a single detector element, | ||
i.e., a line. | ||
""" | ||
|
||
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) | ||
plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0]) | ||
plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1]) | ||
for ax_i in ax: | ||
ax_i.set_xlim(2 * N / 5, N - 2 * N / 5) | ||
ax_i.set_ylim(2 * N / 5, N - 2 * N / 5) | ||
fig.show() | ||
|
||
|
||
input("\nWaiting for input to close figures and exit") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (C) 2020-2023 by SCICO Developers | ||
# All rights reserved. BSD 3-clause License. | ||
# This file is part of the SCICO package. Details of the copyright and | ||
# user license can be found in the 'LICENSE' file distributed with the | ||
# package. | ||
|
||
|
||
""" | ||
X-ray projector classes. | ||
""" | ||
from functools import partial | ||
from typing import Optional | ||
|
||
import numpy as np | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax.typing import ArrayLike | ||
|
||
from scico.typing import Shape | ||
|
||
from ._linop import LinearOperator | ||
|
||
|
||
class XRayProject(LinearOperator): | ||
"""X-ray projection operator. | ||
Wraps an X-ray projector object in a SCICO | ||
:class:`LinearOperator`. | ||
""" | ||
|
||
def __init__(self, projector): | ||
r""" | ||
Args: | ||
projector: instance of an X-ray projector object to wrap, | ||
currently the only option is | ||
:class:`ParallelFixedAxis2dProjector` | ||
""" | ||
self._eval = projector.project | ||
|
||
super().__init__( | ||
input_shape=projector.im_shape, | ||
output_shape=(len(projector.angles), *projector.det_shape), | ||
) | ||
|
||
|
||
class ParallelFixedAxis2dProjector: | ||
"""Parallel ray, single axis, 2D X-ray projector.""" | ||
|
||
def __init__( | ||
self, | ||
im_shape: Shape, | ||
angles: ArrayLike, | ||
det_length: Optional[int] = None, | ||
dither: bool = True, | ||
): | ||
r""" | ||
Args: | ||
im_shape: Shape of input array. | ||
angles: (num_angles,) array of angles in radians. | ||
det_length: Length of detector, in ``None``, defaults to the | ||
length of diagonal of `im_shape`. | ||
dither: If ``True`` randomly shift pixel locations to | ||
reduce projection artifacts caused by aliasing. | ||
""" | ||
self.im_shape = im_shape | ||
self.angles = angles | ||
|
||
im_shape = np.array(im_shape) | ||
|
||
x0 = -(im_shape - 1) / 2 | ||
|
||
if det_length is None: | ||
det_length = int(np.ceil(np.linalg.norm(im_shape))) | ||
self.det_shape = (det_length,) | ||
|
||
y0 = -det_length / 2 | ||
|
||
@jax.vmap | ||
def compute_inds(angle: float) -> ArrayLike: | ||
"""Project pixel positions on to a detector at the given | ||
angle, determine which detector element they contribute to. | ||
""" | ||
x = jnp.stack( | ||
jnp.meshgrid( | ||
*( | ||
jnp.arange(shape_i) * step_i + start_i | ||
for start_i, step_i, shape_i in zip(x0, [1, 1], im_shape) | ||
), | ||
indexing="ij", | ||
), | ||
axis=-1, | ||
) | ||
|
||
# dither | ||
if dither: | ||
key = jax.random.PRNGKey(0) | ||
x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5) | ||
|
||
# project | ||
Px = x[..., 0] * jnp.cos(angle) + x[..., 1] * jnp.sin(angle) | ||
|
||
# quantize | ||
inds = jnp.floor((Px - y0)).astype(int) | ||
|
||
# map negative inds to y_size, which is out of bounds and will be ignored | ||
# otherwise they index from the end like x[-1] | ||
inds = jnp.where(inds < 0, det_length, inds) | ||
|
||
return inds | ||
|
||
inds = compute_inds(angles) # (len(angles), *im_shape) | ||
|
||
@partial(jax.vmap, in_axes=(None, 0)) | ||
def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike: | ||
"""Compute the projection at a single angle.""" | ||
return jnp.zeros(det_length).at[inds].add(im) | ||
|
||
@jax.jit | ||
def project(im: ArrayLike) -> ArrayLike: | ||
"""Compute the projection for all angles.""" | ||
return project_inds(im, inds) | ||
|
||
self.project = project |
Oops, something went wrong.