Skip to content

Commit

Permalink
Add 2D support to radon_svmbir
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Oct 7, 2021
1 parent 50a4866 commit 79449a3
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 52 deletions.
54 changes: 41 additions & 13 deletions scico/linop/radon_svmbir.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,22 @@ def __init__(
angles: Array of projeciton angles in radians, should be increasing.
num_channels: Number of pixels in the sinogram
"""
self.input_shape = input_shape
self.angles = angles
self.num_channels = num_channels

if len(input_shape) == 2: # 2D input
self.svmbir_input_shape = (1,) + input_shape
output_shape = (len(angles), num_channels)
self.svmbir_output_shape = output_shape[0:1] + (1,) + output_shape[1:2]
elif len(input_shape) == 3: # 3D input
self.svmbir_input_shape = input_shape
output_shape = (len(angles), input_shape[0], num_channels)
self.svmbir_output_shape = output_shape
else:
raise ValueError(
f"Only 2D and 3D inputs are supported, but input_shape was {input_shape}"
)

# set up custom_vjp for _eval and _adj so jax.grad works on them
self._eval = jax.custom_vjp(lambda x: self._proj_hcb(x))
self._eval.defvjp(lambda x: (self._proj_hcb(x), None), lambda _, y: (self._bproj_hcb(y),))
Expand All @@ -58,8 +70,8 @@ def __init__(
self._adj.defvjp(lambda y: (self._bproj_hcb(y), None), lambda _, x: (self._proj_hcb(x),))

super().__init__(
input_shape=self.input_shape,
output_shape=(len(angles), input_shape[0], num_channels),
input_shape=input_shape,
output_shape=output_shape,
input_dtype=np.float32,
output_dtype=np.float32,
adj_fn=self._adj,
Expand All @@ -71,24 +83,30 @@ def _proj(x: JaxArray, angles: JaxArray, num_channels: int) -> JaxArray:
return svmbir.project(np.array(x), np.array(angles), num_channels, verbose=0)

def _proj_hcb(self, x):
x = x.reshape(self.svmbir_input_shape)
# host callback wrapper for _proj
return jax.experimental.host_callback.call(
y = jax.experimental.host_callback.call(
lambda x: self._proj(x, self.angles, self.num_channels),
x,
result_shape=jax.ShapeDtypeStruct(self.output_shape, self.output_dtype),
result_shape=jax.ShapeDtypeStruct(self.svmbir_output_shape, self.output_dtype),
)
return y.reshape(self.output_shape)

@staticmethod
def _bproj(y: JaxArray, angles: JaxArray, num_rows: int, num_cols: int):
return svmbir.backproject(np.array(y), np.array(angles), num_rows, num_cols, verbose=0)

def _bproj_hcb(self, y):
y = y.reshape(self.svmbir_output_shape)
# host callback wrapper for _bproj
return jax.experimental.host_callback.call(
lambda y: self._bproj(y, self.angles, self.input_shape[1], self.input_shape[2]),
x = jax.experimental.host_callback.call(
lambda y: self._bproj(
y, self.angles, self.svmbir_input_shape[1], self.svmbir_input_shape[2]
),
y,
result_shape=jax.ShapeDtypeStruct(self.input_shape, self.input_dtype),
result_shape=jax.ShapeDtypeStruct(self.svmbir_input_shape, self.input_dtype),
)
return x.reshape(self.input_shape)


class SVMBIRWeightedSquaredL2Loss(WeightedSquaredL2Loss):
Expand All @@ -113,17 +131,27 @@ def __init__(self, *args, **kwargs):
self.has_prox = True

def prox(self, v: JaxArray, lam: float) -> JaxArray:
v = v.reshape(self.A.svmbir_input_shape)
y = self.y.reshape(self.A.svmbir_output_shape)
weights = self.weights.reshape(self.A.svmbir_output_shape)
sigma_p = snp.sqrt(lam)
result = svmbir.recon(
np.array(self.y),
np.array(y),
np.array(self.A.angles),
weights=np.array(self.weights),
weights=np.array(weights),
prox_image=np.array(v),
num_rows=self.A.input_shape[1],
num_cols=self.A.input_shape[2],
num_rows=self.A.svmbir_input_shape[1],
num_cols=self.A.svmbir_input_shape[2],
sigma_p=np.float(sigma_p),
sigma_y=1.0,
positivity=False,
verbose=0,
)
return result
return result.reshape(self.A.input_shape)


def _unsqueeze(x: JaxArray, input_shape: Shape) -> JaxArray:
"""If x is 2D, make it 3D according to SVMBIR's convention"""
if len(input_shape) == 2:
x = x[snp.newaxis, :, :]
return x
69 changes: 30 additions & 39 deletions scico/test/linop/test_radon_svmbir.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,56 +18,34 @@
pytest.skip("svmbir not installed", allow_module_level=True)


def make_im(Nx, Ny):
BIG_INPUT = (128, 129, 200, 201)
SMALL_INPUT = (4, 5, 7, 8)


def make_im(Nx, Ny, is_3d=True):
x, y = snp.meshgrid(snp.linspace(-1, 1, Nx), snp.linspace(-1, 1, Ny))

im = snp.where((x - 0.25) ** 2 / 3 + y ** 2 < 0.1, 1.0, 0.0)
im = im[snp.newaxis, :, :]
if is_3d:
im = im[snp.newaxis, :, :]
im = im.astype(snp.float32)

return im


@pytest.fixture
def im():
# make everything different sizes to catch dimension swap bugs
Nx, Ny = 128, 129

return make_im(Nx, Ny)


@pytest.fixture
def im_small():
# make everything different sizes to catch dimension swap bugs
Nx, Ny = 4, 5

return make_im(Nx, Ny)


def make_A(im, num_angles, num_channels):
angles = snp.linspace(0, snp.pi, num_angles, dtype=snp.float32)
A = ParallelBeamProjector(im.shape, angles, num_channels)

return A


@pytest.fixture
def A(im):
num_angles = 200
num_channels = 201
@pytest.mark.parametrize("Nx, Ny, num_angles, num_channels", (BIG_INPUT,))
@pytest.mark.parametrize("is_3d", (True, False))
def test_grad(Nx, Ny, num_angles, num_channels, is_3d):
im = make_im(Nx, Ny, is_3d)
A = make_A(im, num_angles, num_channels)

return make_A(im, num_angles, num_channels)


@pytest.fixture
def A_small(im_small):
num_angles = 7
num_channels = 8

return make_A(im_small, num_angles, num_channels)


def test_grad(A, im):
def f(im):
return snp.sum(A._eval(im) ** 2)

Expand All @@ -77,22 +55,35 @@ def f(im):
np.testing.assert_allclose(val_1, val_2)


def test_adjoint(A):
@pytest.mark.parametrize("Nx, Ny, num_angles, num_channels", (BIG_INPUT,))
@pytest.mark.parametrize("is_3d", (True, False))
def test_adjoint(Nx, Ny, num_angles, num_channels, is_3d):
im = make_im(Nx, Ny, is_3d)
A = make_A(im, num_angles, num_channels)

adjoint_AtA_test(A)
adjoint_AAt_test(A)


def test_prox(im_small, A_small):
A, im = A_small, im_small
@pytest.mark.parametrize("Nx, Ny, num_angles, num_channels", (SMALL_INPUT,))
@pytest.mark.parametrize("is_3d", (True, False))
def test_prox(Nx, Ny, num_angles, num_channels, is_3d):
im = make_im(Nx, Ny, is_3d)
A = make_A(im, num_angles, num_channels)

sino = A @ im

v, _ = scico.random.normal(im.shape, dtype=im.dtype)
f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A)
prox_test(v, f, f.prox, alpha=0.25)


def test_prox_weights(im_small, A_small):
A, im = A_small, im_small
@pytest.mark.parametrize("Nx, Ny, num_angles, num_channels", (SMALL_INPUT,))
@pytest.mark.parametrize("is_3d", (True, False))
def test_prox_weights(Nx, Ny, num_angles, num_channels, is_3d):
im = make_im(Nx, Ny, is_3d)
A = make_A(im, num_angles, num_channels)

sino = A @ im

v, _ = scico.random.normal(im.shape, dtype=im.dtype)
Expand Down

0 comments on commit 79449a3

Please sign in to comment.