Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the SCICO x-ray projector #473

Merged
merged 10 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data
Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
51 changes: 25 additions & 26 deletions examples/scripts/ct_projector_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,47 +119,46 @@
"""
Display timing results.

On our server, the SCICO projection is more than twice as fast as ASTRA
when both are run on the GPU, and about 10% slower when both are run the
CPU. The SCICO back projection is slow the first time it is run, probably
due to JIT overhead. After the first run, it is an order of magnitude
faster than ASTRA when both are run on the GPU, and about three times
faster when both are run on the CPU.
On our server, when using the GPU, the SCICO projector (both forward
and backward) is faster than ASTRA. When using the CPU, it is slower
for forward projection and faster for back projection. The SCICO object
initialization and first back projection are slow due to JIT
overhead.

On our server, using the GPU:
```
init astra 1.36e-03 s
init scico 1.37e+01 s
init astra 4.81e-02 s
init scico 2.53e-01 s

first fwd astra 6.92e-02 s
first fwd scico 2.95e-02 s
first fwd astra 4.44e-02 s
first fwd scico 2.82e-02 s

first back astra 4.20e-02 s
first back scico 7.63e+00 s
first back astra 3.31e-02 s
first back scico 2.80e-01 s

avg fwd astra 4.62e-02 s
avg fwd scico 1.61e-02 s
avg fwd astra 4.76e-02 s
avg fwd scico 2.83e-02 s

avg back astra 3.71e-02 s
avg back scico 1.05e-03 s
avg back astra 3.96e-02 s
avg back scico 6.80e-04 s
```

Using the CPU:
```
init astra 1.06e-03 s
init scico 1.00e+01 s
init astra 1.72e-02 s
init scico 2.88e+00 s

first fwd astra 9.16e-01 s
first fwd scico 1.04e+00 s
first fwd astra 1.02e+00 s
first fwd scico 2.40e+00 s

first back astra 9.39e-01 s
first back scico 1.00e+01 s
first back astra 1.03e+00 s
first back scico 3.53e+00 s

avg fwd astra 9.11e-01 s
avg fwd scico 1.03e+00 s
avg fwd astra 1.03e+00 s
avg fwd scico 2.54e+00 s

avg back astra 9.34e-01 s
avg back scico 2.62e-01 s
avg back astra 1.01e+00 s
avg back scico 5.98e-01 s
```
"""

Expand Down
148 changes: 91 additions & 57 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, projector):

super().__init__(
input_shape=projector.im_shape,
output_shape=(len(projector.angles), *projector.det_shape),
output_shape=(len(projector.angles), projector.det_count),
)


Expand All @@ -56,7 +56,6 @@ def __init__(
im_shape: Shape,
angles: ArrayLike,
det_count: Optional[int] = None,
dither: bool = True,
):
r"""
Args:
Expand All @@ -68,65 +67,100 @@ def __init__(
corresponds to summing along antidiagonals.
det_count: Number of elements in detector. If ``None``,
defaults to the size of the diagonal of `im_shape`.
dither: If ``True`` randomly shift pixel locations to
reduce projection artifacts caused by aliasing.
"""
self.im_shape = im_shape
self.angles = angles

im_shape = np.array(im_shape)

x0 = -(im_shape - 1) / 2
self.nx = np.array(im_shape)
self.x0 = np.array([-1, -1])
self.dx = 2 / self.nx

if det_count is None:
det_count = int(np.ceil(np.linalg.norm(im_shape)))
self.det_shape = (det_count,)

y0 = -det_count / 2

@jax.vmap
def compute_inds(angle: float) -> ArrayLike:
"""Project pixel positions on to a detector at the given
angle, determine which detector element they contribute to.
"""
x = jnp.stack(
jnp.meshgrid(
*(
jnp.arange(shape_i) * step_i + start_i
for start_i, step_i, shape_i in zip(x0, [1, 1], im_shape)
),
indexing="ij",
),
axis=-1,
)

# dither
if dither:
key = jax.random.PRNGKey(0)
x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5)

# project
Px = x[..., 0] * jnp.cos(angle) + x[..., 1] * jnp.sin(angle)

# quantize
inds = jnp.floor((Px - y0)).astype(int)

# map negative inds to y_size, which is out of bounds and will be ignored
# otherwise they index from the end like x[-1]
inds = jnp.where(inds < 0, det_count, inds)

return inds

inds = compute_inds(angles) # (len(angles), *im_shape)

@partial(jax.vmap, in_axes=(None, 0))
def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike:
"""Compute the projection at a single angle."""
return jnp.zeros(det_count).at[inds].add(im)

@jax.jit
def project(im: ArrayLike) -> ArrayLike:
"""Compute the projection for all angles."""
return project_inds(im, inds)

self.project = project
self.det_count = det_count
self.ny = det_count

self.y0 = -np.sqrt(2)
self.dy = 2 * np.sqrt(2) / det_count

Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
# scale so dy is 1.0
self.x0 = self.x0 / self.dy
self.dx = self.dx / self.dy
self.y0 = self.y0 / self.dy
self.dy = self.dy / self.dy

def project(self, im):
"""Compute X-ray projection."""
return _project(im, self.x0, self.dx, self.y0, self.ny, self.angles)


@partial(jax.jit, static_argnames=["ny"])
def _project(im, x0, dx, y0, ny, angles):
r"""
Args:
im: Input array, (M, N).
x0: Location of the corner of the pixel im[0,0].
dx: Pixel side length in x- and y-direction. Units are such
that the detector bins have length 1.0.
y0: Location of the edge of the first detector bin.
ny: Number of detector bins.
angles: (num_angles,) array of angles in radians. Pixels are
projected onto units vectors pointing in these directions.
"""
nx = im.shape
inds, weights = _calc_weights(x0, dx, nx, angles, y0)

y = (
jnp.zeros((len(angles), ny))
.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]
.add(im * weights)
)

y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights))

return y


@partial(jax.jit, static_argnames=["nx", "y0"])
@partial(jax.vmap, in_axes=(None, None, None, 0, None))
def _calc_weights(x0, dx, nx, angle, y0):
"""

Args:
x0: Location of the corner of the pixel im[0,0].
dx: Pixel side length in x- and y-direction. Units are such
that the detector bins have length 1.0.
nx: Input image shape.
angle: (num_angles,) array of angles in radians. Pixels are
projected onto units vectors pointing in these directions.
(The argument is `vmap`ed.)
Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
y0: Location of the edge of the first detector bin.
"""
u = [jnp.cos(angle), jnp.sin(angle)]
Px0 = x0[0] * u[0] + x0[1] * u[1] - y0
Pdx = [dx[0] * u[0], dx[1] * u[1]]
Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]]))

Px = (
Pxmin
+ Pdx[0] * jnp.arange(nx[0]).reshape(-1, 1)
+ Pdx[1] * jnp.arange(nx[1]).reshape(1, -1)
)

# detector bin inds
inds = jnp.floor(Px).astype(int)

# weights
Pdx = jnp.array(u) * jnp.array(dx)
diag1 = jnp.abs(Pdx[0] + Pdx[1])
diag2 = jnp.abs(Pdx[0] - Pdx[1])
w = jnp.max(jnp.array([diag1, diag2]))
f = jnp.min(jnp.array([diag1, diag2]))

width = (w + f) / 2

Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
distance_to_next = 1 - (Px - inds) # always in (0, 1]

Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
weights = jnp.minimum(distance_to_next, width) / width

return inds, weights
26 changes: 26 additions & 0 deletions scico/test/linop/xray/test_xray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax.numpy as jnp

import scico
from scico.linop import Parallel2dProjector, XRayTransform


Expand All @@ -24,3 +25,28 @@ def test_apply():
# dither off
H = XRayTransform(Parallel2dProjector(x.shape, angles, dither=False))
y = H @ x


def test_apply_adjoint():
im_shape = (12, 13)
num_angles = 10
x = jnp.ones(im_shape)

angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)

# general projection
H = XRayTransform(Parallel2dProjector(x.shape, angles))
y = H @ x
assert y.shape[0] == (num_angles)

# adjoint
bp = H.T @ y
assert scico.linop.valid_adjoint(
H, H.T, eps=1e-5
) # associative reductions might cause small errors, hence 1e-5

# fixed det_length
det_length = 14
H = XRayTransform(Parallel2dProjector(x.shape, angles, det_length=det_length))
y = H @ x
assert y.shape[1] == det_length
Loading