Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a jax-based X-ray projector #433

Merged
merged 14 commits into from
Sep 27, 2023
2 changes: 1 addition & 1 deletion data
1 change: 1 addition & 0 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Computed Tomography
examples/ct_astra_modl_train_foam2
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_projector_comparison


Deconvolution
Expand Down
2 changes: 2 additions & 0 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ Computed Tomography
CT Training and Reconstructions with ODP
`ct_astra_unet_train_foam2.py <ct_astra_unet_train_foam2.py>`_
CT Training and Reconstructions with UNet
`ct_projector_comparison.py <ct_projector_comparison.py>`_
X-ray Projector Comparison


Deconvolution
Expand Down
209 changes: 209 additions & 0 deletions examples/scripts/ct_projector_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.


r"""
X-ray Projector Comparison
==========================

This example compares SCICO's native X-ray projection algorithm
to that of the ASTRA Toolbox.
"""

import numpy as np

import jax
import jax.numpy as jnp

from xdesign import Foam, discrete_phantom

from scico import plot
from scico.linop import ParallelFixedAxis2dProjector, XRayProject
from scico.linop.radon_astra import TomographicProjector
from scico.util import Timer

"""
Create a ground truth image.
"""

N = 512


det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))

x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jax.device_put(x_gt)

"""
Time projector instantiation.
"""

num_angles = 500
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)


timer = Timer()

projectors = {}
timer.start("scico_init")
projectors["scico"] = XRayProject(ParallelFixedAxis2dProjector((N, N), angles))
timer.stop("scico_init")

timer.start("astra_init")
projectors["astra"] = TomographicProjector(
(N, N), detector_spacing=1.0, det_count=det_count, angles=angles
)
timer.stop("astra_init")

"""
Time first projector application, which might include JIT overhead.
"""

ys = {}
for name, H in projectors.items():
timer_label = f"{name}_first_proj"
timer.start(timer_label)
ys[name] = H @ x_gt
jax.block_until_ready(ys[name])
timer.stop(timer_label)


"""
Compute average time for a projector application.
"""

num_repeats = 3
for name, H in projectors.items():
timer_label = f"{name}_avg_proj"
timer.start(timer_label)
for _ in range(num_repeats):
ys[name] = H @ x_gt
jax.block_until_ready(ys[name])
timer.stop(timer_label)
timer.td[timer_label] /= num_repeats

"""
Display timing results.

On our server, the SCICO projection is more than twice
as fast as ASTRA when both are run on the GPU, and about
10% slower when both are run the CPU.

On our server, using the GPU:
```
Label Accum. Current
-------------------------------------------
astra_avg_proj 4.62e-02 s Stopped
astra_first_proj 6.92e-02 s Stopped
astra_init 1.36e-03 s Stopped
scico_avg_proj 1.61e-02 s Stopped
scico_first_proj 2.95e-02 s Stopped
scico_init 1.37e+01 s Stopped
```

Using the CPU:
```
Label Accum. Current
-------------------------------------------
astra_avg_proj 9.11e-01 s Stopped
astra_first_proj 9.16e-01 s Stopped
astra_init 1.06e-03 s Stopped
scico_avg_proj 1.03e+00 s Stopped
scico_first_proj 1.04e+00 s Stopped
scico_init 1.00e+01 s Stopped
```
"""

print(timer)

"""
Show projections.
"""

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3))
plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0])
plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1])
fig.show()


"""
Time first back projection, which might include JIT overhead.
"""
timer = Timer()

y = np.zeros(H.output_shape, dtype=np.float32)
y[num_angles // 3, det_count // 2] = 1.0
y = jax.device_put(y)

HTys = {}
for name, H in projectors.items():
timer_label = f"{name}_first_BP"
timer.start(timer_label)
HTys[name] = H.T @ y
jax.block_until_ready(ys[name])
timer.stop(timer_label)


"""
Compute average time for back projection.
"""
num_repeats = 3
for name, H in projectors.items():
timer_label = f"{name}_avg_BP"
timer.start(timer_label)
for _ in range(num_repeats):
HTys[name] = H.T @ y
jax.block_until_ready(ys[name])
timer.stop(timer_label)
timer.td[timer_label] /= num_repeats

"""
Display back projection timing results.

On our server, the SCICO back projection is slow
the first time it is run, probably due to JIT overhead.
After the first run, it is an order of magnitude
faster than ASTRA when both are run on the GPU,
and about three times faster when both are run on the CPU.

On our server, using the GPU:
```
Label Accum. Current
-----------------------------------------
astra_avg_BP 3.71e-02 s Stopped
astra_first_BP 4.20e-02 s Stopped
scico_avg_BP 1.05e-03 s Stopped
scico_first_BP 7.63e+00 s Stopped
```

Using the CPU:
```
Label Accum. Current
-----------------------------------------
astra_avg_BP 9.34e-01 s Stopped
astra_first_BP 9.39e-01 s Stopped
scico_avg_BP 2.62e-01 s Stopped
scico_first_BP 1.00e+01 s Stopped
```
"""

print(timer)

"""
Show back projections of a single detector element,
i.e., a line.
"""

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3))
plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0])
plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1])
for ax_i in ax:
ax_i.set_xlim(2 * N / 5, N - 2 * N / 5)
ax_i.set_ylim(2 * N / 5, N - 2 * N / 5)
fig.show()


input("\nWaiting for input to close figures and exit")
1 change: 1 addition & 0 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Computed Tomography
- ct_astra_modl_train_foam2.py
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_projector_comparison.py


Deconvolution
Expand Down
3 changes: 3 additions & 0 deletions scico/linop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ._matrix import MatrixOperator
from ._stack import DiagonalStack, VerticalStack
from ._util import jacobian, operator_norm, power_iteration, valid_adjoint
from ._xray import ParallelFixedAxis2dProjector, XRayProject

__all__ = [
"CircularConvolve",
Expand All @@ -38,6 +39,8 @@
"Sum",
"Transpose",
"LinearOperator",
"XRayProject",
"ParallelFixedAxis2dProjector",
"ComposedLinearOperator",
"linop_from_function",
"operator_norm",
Expand Down
125 changes: 125 additions & 0 deletions scico/linop/_xray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.


"""
X-ray projector classes.
"""
from functools import partial
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from scico.typing import Shape

from ._linop import LinearOperator


class XRayProject(LinearOperator):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming is not consistent with existing CT projector classes: worth discussion.

Copy link
Contributor Author

@Michael-T-McCann Michael-T-McCann Jul 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discussion: RadonTransform here, change the others in a separate PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New discussion: XRayTransform

"""X-ray projection operator.

Wraps an X-ray projector object in a SCICO
:class:`LinearOperator`.
"""

def __init__(self, projector):
r"""
Args:
projector: instance of an X-ray projector object to wrap,
currently the only option is
:class:`ParallelFixedAxis2dProjector`
"""
self._eval = projector.project

super().__init__(
input_shape=projector.im_shape,
output_shape=(len(projector.angles), *projector.det_shape),
)


class ParallelFixedAxis2dProjector:
Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
"""Parallel ray, single axis, 2D X-ray projector."""

def __init__(
self,
im_shape: Shape,
angles: ArrayLike,
det_length: Optional[int] = None,
dither: bool = True,
):
r"""
Args:
im_shape: Shape of input array.
angles: (num_angles,) array of angles in radians.
det_length: Length of detector, in ``None``, defaults to the
length of diagonal of `im_shape`.
dither: If ``True`` randomly shift pixel locations to
reduce projection artifacts caused by aliasing.
"""
self.im_shape = im_shape
self.angles = angles

im_shape = np.array(im_shape)

x0 = -(im_shape - 1) / 2

if det_length is None:
det_length = int(np.ceil(np.linalg.norm(im_shape)))
self.det_shape = (det_length,)

y0 = -det_length / 2

@jax.vmap
def compute_inds(angle: float) -> ArrayLike:
"""Project pixel positions on to a detector at the given
angle, determine which detector element they contribute to.
"""
x = jnp.stack(
jnp.meshgrid(
*(
jnp.arange(shape_i) * step_i + start_i
for start_i, step_i, shape_i in zip(x0, [1, 1], im_shape)
),
indexing="ij",
),
axis=-1,
)

# dither
if dither:
key = jax.random.PRNGKey(0)
x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5)

# project
Px = x[..., 0] * jnp.cos(angle) + x[..., 1] * jnp.sin(angle)

# quantize
inds = jnp.floor((Px - y0)).astype(int)

# map negative inds to y_size, which is out of bounds and will be ignored
# otherwise they index from the end like x[-1]
inds = jnp.where(inds < 0, det_length, inds)

return inds

inds = compute_inds(angles) # (len(angles), *im_shape)

@partial(jax.vmap, in_axes=(None, 0))
def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike:
"""Compute the projection at a single angle."""
return jnp.zeros(det_length).at[inds].add(im)

@jax.jit
def project(im: ArrayLike) -> ArrayLike:
"""Compute the projection for all angles."""
return project_inds(im, inds)

self.project = project
Loading
Loading