Skip to content

Commit

Permalink
Add a jax-based X-ray projector (#433)
Browse files Browse the repository at this point in the history
* Add 2d projector and code to time it

* Clean up

* Add back projection

* Add test

* Add timing results to example

* Start to add new example

* Update data

* Address mypy

* Address isort

* Try to fix tables in notebook

* Update submodule

* Rename parameter

* Update ray syntax to get tests passing in CI

---------

Co-authored-by: Michael McCann <[email protected]>
Co-authored-by: Michael McCann <[email protected]>
Co-authored-by: Brendt Wohlberg <[email protected]>
  • Loading branch information
4 people authored Sep 27, 2023
1 parent 6fe8536 commit 085bbcc
Show file tree
Hide file tree
Showing 9 changed files with 376 additions and 14 deletions.
2 changes: 1 addition & 1 deletion data
1 change: 1 addition & 0 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Computed Tomography
examples/ct_astra_modl_train_foam2
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_projector_comparison


Deconvolution
Expand Down
2 changes: 2 additions & 0 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ Computed Tomography
CT Training and Reconstructions with ODP
`ct_astra_unet_train_foam2.py <ct_astra_unet_train_foam2.py>`_
CT Training and Reconstructions with UNet
`ct_projector_comparison.py <ct_projector_comparison.py>`_
X-ray Projector Comparison


Deconvolution
Expand Down
209 changes: 209 additions & 0 deletions examples/scripts/ct_projector_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.


r"""
X-ray Projector Comparison
==========================
This example compares SCICO's native X-ray projection algorithm
to that of the ASTRA Toolbox.
"""

import numpy as np

import jax
import jax.numpy as jnp

from xdesign import Foam, discrete_phantom

from scico import plot
from scico.linop import ParallelFixedAxis2dProjector, XRayProject
from scico.linop.radon_astra import TomographicProjector
from scico.util import Timer

"""
Create a ground truth image.
"""

N = 512


det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))

x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jax.device_put(x_gt)

"""
Time projector instantiation.
"""

num_angles = 500
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)


timer = Timer()

projectors = {}
timer.start("scico_init")
projectors["scico"] = XRayProject(ParallelFixedAxis2dProjector((N, N), angles))
timer.stop("scico_init")

timer.start("astra_init")
projectors["astra"] = TomographicProjector(
(N, N), detector_spacing=1.0, det_count=det_count, angles=angles
)
timer.stop("astra_init")

"""
Time first projector application, which might include JIT overhead.
"""

ys = {}
for name, H in projectors.items():
timer_label = f"{name}_first_proj"
timer.start(timer_label)
ys[name] = H @ x_gt
jax.block_until_ready(ys[name])
timer.stop(timer_label)


"""
Compute average time for a projector application.
"""

num_repeats = 3
for name, H in projectors.items():
timer_label = f"{name}_avg_proj"
timer.start(timer_label)
for _ in range(num_repeats):
ys[name] = H @ x_gt
jax.block_until_ready(ys[name])
timer.stop(timer_label)
timer.td[timer_label] /= num_repeats

"""
Display timing results.
On our server, the SCICO projection is more than twice
as fast as ASTRA when both are run on the GPU, and about
10% slower when both are run the CPU.
On our server, using the GPU:
```
Label Accum. Current
-------------------------------------------
astra_avg_proj 4.62e-02 s Stopped
astra_first_proj 6.92e-02 s Stopped
astra_init 1.36e-03 s Stopped
scico_avg_proj 1.61e-02 s Stopped
scico_first_proj 2.95e-02 s Stopped
scico_init 1.37e+01 s Stopped
```
Using the CPU:
```
Label Accum. Current
-------------------------------------------
astra_avg_proj 9.11e-01 s Stopped
astra_first_proj 9.16e-01 s Stopped
astra_init 1.06e-03 s Stopped
scico_avg_proj 1.03e+00 s Stopped
scico_first_proj 1.04e+00 s Stopped
scico_init 1.00e+01 s Stopped
```
"""

print(timer)

"""
Show projections.
"""

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3))
plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0])
plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1])
fig.show()


"""
Time first back projection, which might include JIT overhead.
"""
timer = Timer()

y = np.zeros(H.output_shape, dtype=np.float32)
y[num_angles // 3, det_count // 2] = 1.0
y = jax.device_put(y)

HTys = {}
for name, H in projectors.items():
timer_label = f"{name}_first_BP"
timer.start(timer_label)
HTys[name] = H.T @ y
jax.block_until_ready(ys[name])
timer.stop(timer_label)


"""
Compute average time for back projection.
"""
num_repeats = 3
for name, H in projectors.items():
timer_label = f"{name}_avg_BP"
timer.start(timer_label)
for _ in range(num_repeats):
HTys[name] = H.T @ y
jax.block_until_ready(ys[name])
timer.stop(timer_label)
timer.td[timer_label] /= num_repeats

"""
Display back projection timing results.
On our server, the SCICO back projection is slow
the first time it is run, probably due to JIT overhead.
After the first run, it is an order of magnitude
faster than ASTRA when both are run on the GPU,
and about three times faster when both are run on the CPU.
On our server, using the GPU:
```
Label Accum. Current
-----------------------------------------
astra_avg_BP 3.71e-02 s Stopped
astra_first_BP 4.20e-02 s Stopped
scico_avg_BP 1.05e-03 s Stopped
scico_first_BP 7.63e+00 s Stopped
```
Using the CPU:
```
Label Accum. Current
-----------------------------------------
astra_avg_BP 9.34e-01 s Stopped
astra_first_BP 9.39e-01 s Stopped
scico_avg_BP 2.62e-01 s Stopped
scico_first_BP 1.00e+01 s Stopped
```
"""

print(timer)

"""
Show back projections of a single detector element,
i.e., a line.
"""

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3))
plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0])
plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1])
for ax_i in ax:
ax_i.set_xlim(2 * N / 5, N - 2 * N / 5)
ax_i.set_ylim(2 * N / 5, N - 2 * N / 5)
fig.show()


input("\nWaiting for input to close figures and exit")
1 change: 1 addition & 0 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Computed Tomography
- ct_astra_modl_train_foam2.py
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_projector_comparison.py


Deconvolution
Expand Down
3 changes: 3 additions & 0 deletions scico/linop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ._matrix import MatrixOperator
from ._stack import DiagonalStack, VerticalStack
from ._util import jacobian, operator_norm, power_iteration, valid_adjoint
from ._xray import ParallelFixedAxis2dProjector, XRayProject

__all__ = [
"CircularConvolve",
Expand All @@ -38,6 +39,8 @@
"Sum",
"Transpose",
"LinearOperator",
"XRayProject",
"ParallelFixedAxis2dProjector",
"ComposedLinearOperator",
"linop_from_function",
"operator_norm",
Expand Down
125 changes: 125 additions & 0 deletions scico/linop/_xray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.


"""
X-ray projector classes.
"""
from functools import partial
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from scico.typing import Shape

from ._linop import LinearOperator


class XRayProject(LinearOperator):
"""X-ray projection operator.
Wraps an X-ray projector object in a SCICO
:class:`LinearOperator`.
"""

def __init__(self, projector):
r"""
Args:
projector: instance of an X-ray projector object to wrap,
currently the only option is
:class:`ParallelFixedAxis2dProjector`
"""
self._eval = projector.project

super().__init__(
input_shape=projector.im_shape,
output_shape=(len(projector.angles), *projector.det_shape),
)


class ParallelFixedAxis2dProjector:
"""Parallel ray, single axis, 2D X-ray projector."""

def __init__(
self,
im_shape: Shape,
angles: ArrayLike,
det_length: Optional[int] = None,
dither: bool = True,
):
r"""
Args:
im_shape: Shape of input array.
angles: (num_angles,) array of angles in radians.
det_length: Length of detector, in ``None``, defaults to the
length of diagonal of `im_shape`.
dither: If ``True`` randomly shift pixel locations to
reduce projection artifacts caused by aliasing.
"""
self.im_shape = im_shape
self.angles = angles

im_shape = np.array(im_shape)

x0 = -(im_shape - 1) / 2

if det_length is None:
det_length = int(np.ceil(np.linalg.norm(im_shape)))
self.det_shape = (det_length,)

y0 = -det_length / 2

@jax.vmap
def compute_inds(angle: float) -> ArrayLike:
"""Project pixel positions on to a detector at the given
angle, determine which detector element they contribute to.
"""
x = jnp.stack(
jnp.meshgrid(
*(
jnp.arange(shape_i) * step_i + start_i
for start_i, step_i, shape_i in zip(x0, [1, 1], im_shape)
),
indexing="ij",
),
axis=-1,
)

# dither
if dither:
key = jax.random.PRNGKey(0)
x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5)

# project
Px = x[..., 0] * jnp.cos(angle) + x[..., 1] * jnp.sin(angle)

# quantize
inds = jnp.floor((Px - y0)).astype(int)

# map negative inds to y_size, which is out of bounds and will be ignored
# otherwise they index from the end like x[-1]
inds = jnp.where(inds < 0, det_length, inds)

return inds

inds = compute_inds(angles) # (len(angles), *im_shape)

@partial(jax.vmap, in_axes=(None, 0))
def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike:
"""Compute the projection at a single angle."""
return jnp.zeros(det_length).at[inds].add(im)

@jax.jit
def project(im: ArrayLike) -> ArrayLike:
"""Compute the projection for all angles."""
return project_inds(im, inds)

self.project = project
Loading

0 comments on commit 085bbcc

Please sign in to comment.