-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a jax-based X-ray projector #433
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
18e3918
Add 2d projector and code to time it
099d6e4
Clean up
542c299
Add back projection
Michael-T-McCann 3f623b6
Add test
Michael-T-McCann 26f721f
Add timing results to example
Michael-T-McCann 879d4b9
Start to add new example
Michael-T-McCann 84ba712
Update data
Michael-T-McCann 463f40a
Address mypy
Michael-T-McCann 7a2cc3a
Address isort
Michael-T-McCann c63b647
Try to fix tables in notebook
Michael-T-McCann 8b35d85
Update submodule
bwohlberg 1708d09
Merge branch 'main' into mike/2D_CT_projector
bwohlberg 83d129f
Rename parameter
Michael-T-McCann 3757233
Update ray syntax to get tests passing in CI
Michael-T-McCann File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Submodule data
updated
2 files
+492 −0 | notebooks/ct_projector_comparison.ipynb | |
+2 −1 | notebooks/index.ipynb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# This file is part of the SCICO package. Details of the copyright | ||
# and user license can be found in the 'LICENSE.txt' file distributed | ||
# with the package. | ||
|
||
|
||
r""" | ||
X-ray Projector Comparison | ||
========================== | ||
|
||
This example compares SCICO's native X-ray projection algorithm | ||
to that of the ASTRA Toolbox. | ||
""" | ||
|
||
import numpy as np | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
from xdesign import Foam, discrete_phantom | ||
|
||
from scico import plot | ||
from scico.linop import ParallelFixedAxis2dProjector, XRayProject | ||
from scico.linop.radon_astra import TomographicProjector | ||
from scico.util import Timer | ||
|
||
""" | ||
Create a ground truth image. | ||
""" | ||
|
||
N = 512 | ||
|
||
|
||
det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) | ||
|
||
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) | ||
x_gt = jax.device_put(x_gt) | ||
|
||
""" | ||
Time projector instantiation. | ||
""" | ||
|
||
num_angles = 500 | ||
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) | ||
|
||
|
||
timer = Timer() | ||
|
||
projectors = {} | ||
timer.start("scico_init") | ||
projectors["scico"] = XRayProject(ParallelFixedAxis2dProjector((N, N), angles)) | ||
timer.stop("scico_init") | ||
|
||
timer.start("astra_init") | ||
projectors["astra"] = TomographicProjector( | ||
(N, N), detector_spacing=1.0, det_count=det_count, angles=angles | ||
) | ||
timer.stop("astra_init") | ||
|
||
""" | ||
Time first projector application, which might include JIT overhead. | ||
""" | ||
|
||
ys = {} | ||
for name, H in projectors.items(): | ||
timer_label = f"{name}_first_proj" | ||
timer.start(timer_label) | ||
ys[name] = H @ x_gt | ||
jax.block_until_ready(ys[name]) | ||
timer.stop(timer_label) | ||
|
||
|
||
""" | ||
Compute average time for a projector application. | ||
""" | ||
|
||
num_repeats = 3 | ||
for name, H in projectors.items(): | ||
timer_label = f"{name}_avg_proj" | ||
timer.start(timer_label) | ||
for _ in range(num_repeats): | ||
ys[name] = H @ x_gt | ||
jax.block_until_ready(ys[name]) | ||
timer.stop(timer_label) | ||
timer.td[timer_label] /= num_repeats | ||
|
||
""" | ||
Display timing results. | ||
|
||
On our server, the SCICO projection is more than twice | ||
as fast as ASTRA when both are run on the GPU, and about | ||
10% slower when both are run the CPU. | ||
|
||
On our server, using the GPU: | ||
``` | ||
Label Accum. Current | ||
------------------------------------------- | ||
astra_avg_proj 4.62e-02 s Stopped | ||
astra_first_proj 6.92e-02 s Stopped | ||
astra_init 1.36e-03 s Stopped | ||
scico_avg_proj 1.61e-02 s Stopped | ||
scico_first_proj 2.95e-02 s Stopped | ||
scico_init 1.37e+01 s Stopped | ||
``` | ||
|
||
Using the CPU: | ||
``` | ||
Label Accum. Current | ||
------------------------------------------- | ||
astra_avg_proj 9.11e-01 s Stopped | ||
astra_first_proj 9.16e-01 s Stopped | ||
astra_init 1.06e-03 s Stopped | ||
scico_avg_proj 1.03e+00 s Stopped | ||
scico_first_proj 1.04e+00 s Stopped | ||
scico_init 1.00e+01 s Stopped | ||
``` | ||
""" | ||
|
||
print(timer) | ||
|
||
""" | ||
Show projections. | ||
""" | ||
|
||
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) | ||
plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0]) | ||
plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1]) | ||
fig.show() | ||
|
||
|
||
""" | ||
Time first back projection, which might include JIT overhead. | ||
""" | ||
timer = Timer() | ||
|
||
y = np.zeros(H.output_shape, dtype=np.float32) | ||
y[num_angles // 3, det_count // 2] = 1.0 | ||
y = jax.device_put(y) | ||
|
||
HTys = {} | ||
for name, H in projectors.items(): | ||
timer_label = f"{name}_first_BP" | ||
timer.start(timer_label) | ||
HTys[name] = H.T @ y | ||
jax.block_until_ready(ys[name]) | ||
timer.stop(timer_label) | ||
|
||
|
||
""" | ||
Compute average time for back projection. | ||
""" | ||
num_repeats = 3 | ||
for name, H in projectors.items(): | ||
timer_label = f"{name}_avg_BP" | ||
timer.start(timer_label) | ||
for _ in range(num_repeats): | ||
HTys[name] = H.T @ y | ||
jax.block_until_ready(ys[name]) | ||
timer.stop(timer_label) | ||
timer.td[timer_label] /= num_repeats | ||
|
||
""" | ||
Display back projection timing results. | ||
|
||
On our server, the SCICO back projection is slow | ||
the first time it is run, probably due to JIT overhead. | ||
After the first run, it is an order of magnitude | ||
faster than ASTRA when both are run on the GPU, | ||
and about three times faster when both are run on the CPU. | ||
|
||
On our server, using the GPU: | ||
``` | ||
Label Accum. Current | ||
----------------------------------------- | ||
astra_avg_BP 3.71e-02 s Stopped | ||
astra_first_BP 4.20e-02 s Stopped | ||
scico_avg_BP 1.05e-03 s Stopped | ||
scico_first_BP 7.63e+00 s Stopped | ||
``` | ||
|
||
Using the CPU: | ||
``` | ||
Label Accum. Current | ||
----------------------------------------- | ||
astra_avg_BP 9.34e-01 s Stopped | ||
astra_first_BP 9.39e-01 s Stopped | ||
scico_avg_BP 2.62e-01 s Stopped | ||
scico_first_BP 1.00e+01 s Stopped | ||
``` | ||
""" | ||
|
||
print(timer) | ||
|
||
""" | ||
Show back projections of a single detector element, | ||
i.e., a line. | ||
""" | ||
|
||
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3)) | ||
plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0]) | ||
plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1]) | ||
for ax_i in ax: | ||
ax_i.set_xlim(2 * N / 5, N - 2 * N / 5) | ||
ax_i.set_ylim(2 * N / 5, N - 2 * N / 5) | ||
fig.show() | ||
|
||
|
||
input("\nWaiting for input to close figures and exit") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (C) 2020-2023 by SCICO Developers | ||
# All rights reserved. BSD 3-clause License. | ||
# This file is part of the SCICO package. Details of the copyright and | ||
# user license can be found in the 'LICENSE' file distributed with the | ||
# package. | ||
|
||
|
||
""" | ||
X-ray projector classes. | ||
""" | ||
from functools import partial | ||
from typing import Optional | ||
|
||
import numpy as np | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax.typing import ArrayLike | ||
|
||
from scico.typing import Shape | ||
|
||
from ._linop import LinearOperator | ||
|
||
|
||
class XRayProject(LinearOperator): | ||
"""X-ray projection operator. | ||
|
||
Wraps an X-ray projector object in a SCICO | ||
:class:`LinearOperator`. | ||
""" | ||
|
||
def __init__(self, projector): | ||
r""" | ||
Args: | ||
projector: instance of an X-ray projector object to wrap, | ||
currently the only option is | ||
:class:`ParallelFixedAxis2dProjector` | ||
""" | ||
self._eval = projector.project | ||
|
||
super().__init__( | ||
input_shape=projector.im_shape, | ||
output_shape=(len(projector.angles), *projector.det_shape), | ||
) | ||
|
||
|
||
class ParallelFixedAxis2dProjector: | ||
Michael-T-McCann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Parallel ray, single axis, 2D X-ray projector.""" | ||
|
||
def __init__( | ||
self, | ||
im_shape: Shape, | ||
angles: ArrayLike, | ||
det_length: Optional[int] = None, | ||
dither: bool = True, | ||
): | ||
r""" | ||
Args: | ||
im_shape: Shape of input array. | ||
angles: (num_angles,) array of angles in radians. | ||
det_length: Length of detector, in ``None``, defaults to the | ||
length of diagonal of `im_shape`. | ||
dither: If ``True`` randomly shift pixel locations to | ||
reduce projection artifacts caused by aliasing. | ||
""" | ||
self.im_shape = im_shape | ||
self.angles = angles | ||
|
||
im_shape = np.array(im_shape) | ||
|
||
x0 = -(im_shape - 1) / 2 | ||
|
||
if det_length is None: | ||
det_length = int(np.ceil(np.linalg.norm(im_shape))) | ||
self.det_shape = (det_length,) | ||
|
||
y0 = -det_length / 2 | ||
|
||
@jax.vmap | ||
def compute_inds(angle: float) -> ArrayLike: | ||
"""Project pixel positions on to a detector at the given | ||
angle, determine which detector element they contribute to. | ||
""" | ||
x = jnp.stack( | ||
jnp.meshgrid( | ||
*( | ||
jnp.arange(shape_i) * step_i + start_i | ||
for start_i, step_i, shape_i in zip(x0, [1, 1], im_shape) | ||
), | ||
indexing="ij", | ||
), | ||
axis=-1, | ||
) | ||
|
||
# dither | ||
if dither: | ||
key = jax.random.PRNGKey(0) | ||
x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5) | ||
|
||
# project | ||
Px = x[..., 0] * jnp.cos(angle) + x[..., 1] * jnp.sin(angle) | ||
|
||
# quantize | ||
inds = jnp.floor((Px - y0)).astype(int) | ||
|
||
# map negative inds to y_size, which is out of bounds and will be ignored | ||
# otherwise they index from the end like x[-1] | ||
inds = jnp.where(inds < 0, det_length, inds) | ||
|
||
return inds | ||
|
||
inds = compute_inds(angles) # (len(angles), *im_shape) | ||
|
||
@partial(jax.vmap, in_axes=(None, 0)) | ||
def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike: | ||
"""Compute the projection at a single angle.""" | ||
return jnp.zeros(det_length).at[inds].add(im) | ||
|
||
@jax.jit | ||
def project(im: ArrayLike) -> ArrayLike: | ||
"""Compute the projection for all angles.""" | ||
return project_inds(im, inds) | ||
|
||
self.project = project |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Naming is not consistent with existing CT projector classes: worth discussion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
discussion: RadonTransform here, change the others in a separate PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New discussion:
XRayTransform