Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael McCann authored and Michael-T-McCann committed Jul 28, 2023
1 parent 18e3918 commit 842a402
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,44 @@
# with the package.


import time
r"""
X-ray Projector Comparison
==========================
This example compares SCICO's native X-ray projection algorithm
to that of the ASTRA Toolbox.
"""

import jax
import jax.numpy as jnp
from xdesign import Foam, discrete_phantom

from scico import plot
from scico.linop import XRayProject, ParallelFixedAxis2dProjector
from scico.linop.radon_astra import TomographicProjector
from scico.util import Timer

"""
Create a ground truth image.
"""

N = 512
num_angles = 512


det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))

x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jax.device_put(x_gt)

"""
Time projector instantiation.
"""

num_angles = 512
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)

method_names = ["scico", "astra"]
timer = Timer(
[n + "_init" for n in method_names]
+ [n + "_first_proj" for n in method_names]
+ [n + "_avg_proj" for n in method_names]
)

timer = Timer()

projectors = {}
timer.start("scico_init")
Expand All @@ -44,42 +55,52 @@
)
timer.stop("astra_init")

"""
Time first projector application, which might include JIT overhead.
"""

ys = {}
for name, H in projectors.items():
timer_label = f"{name}_first_proj"
timer.start(timer_label)
ys[name] = H @ x_gt
jax.block_until_ready(ys[name])
timer.stop(timer_label)


"""
Compute average time for a projector application.
"""

num_repeats = 3
for name, H in projectors.items():
timer_label = f"{name}_avg_proj"
timer.start(timer_label)
for _ in range(num_repeats):
ys[name] = H @ x_gt
jax.block_until_ready(ys[name])
timer.stop(timer_label)
timer.td[timer_label] /= num_repeats

"""
Display timing results.
On our server, using the GPU:
Using the CPU:
"""

print(timer)

"""
with way 2:
Label Accum. Current
-------------------------------------------
astra_avg_proj 7.30e-01 s Stopped
astra_first_proj 7.41e-01 s Stopped
astra_init 4.63e-03 s Stopped
scico_avg_proj 9.96e-01 s Stopped
scico_first_proj 9.98e-01 s Stopped
scico_init 8.02e+00 s Stopped
Show projections.
"""

fig, ax = plt.subplots()
ax.imshow(ys["scico"])
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 5))
plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0])
plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1])
fig.show()

fig, ax = plt.subplots()
ax.imshow(ys["astra"])
fig.show()
input("\nWaiting for input to close figures and exit")
51 changes: 10 additions & 41 deletions scico/linop/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,49 +87,18 @@ def compute_inds(angle: float) -> ArrayLike:

return inds

inds = compute_inds(angles)
inds = compute_inds(angles) # (len(angles), *im_shape)

@partial(jax.vmap, in_axes=(None, 0))
def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike:
return jnp.zeros(det_length).at[inds].add(im)
# @partial(jax.vmap, in_axes=(None, 0))
# def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike:
# return jnp.zeros(det_length).at[inds].add(im)

# @jax.jit
# def project(im: ArrayLike) -> ArrayLike:
# return project_inds(im, inds)

@jax.jit
def project(im: ArrayLike) -> ArrayLike:
return project_inds(im, inds)
def project(im):
return jnp.zeros((len(angles), det_length)).at[inds].add(im)

self.project = project


# num_angles = 127

# x = jnp.ones((128, 129))


# H = ParallelFixedAxis2dProjector(x.shape, angles)
# y1 = H.project(x)

# import matplotlib.pyplot as plt

# fig, ax = plt.subplots()
# ax.imshow(y)
# fig.show()

# f = lambda x: H.project(x)[65, 90]
# grad_f = jax.grad(f)

# fig, ax = plt.subplots()
# ax.imshow(grad_f(x))
# fig.show()


# ## back project


# bad_angle = jnp.array([jnp.pi / 4])
# H = ParallelFixedAxis2dProjector(x.shape, bad_angle)
# y = H.project(x)


# fig, ax = plt.subplots()
# ax.plot(y[0])
# fig.show()

0 comments on commit 842a402

Please sign in to comment.