diff --git a/CHANGES.rst b/CHANGES.rst index 36dc57db3..ae2435033 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,12 +7,15 @@ Version 0.0.2 (unreleased) ---------------------------- • Additional optimization algorithms: Linearized ADMM and PDHG. +• Additional Abel transform and array slicing linear operators. +• Additional nuclear norm functional. +• New module ``scico.ray.tune`` providing a simplified interface to Ray Tune. • Move optimization algorithms into ``optimize`` subpackage. • Additional iteration stats columns for iterative ADMM subproblem solvers. • Renamed "Primal Rsdl" to "Prml Rsdl" in displayed iteration stats. • Move some functions from ``util`` and ``math`` modules to new ``array`` module. -• Bump pinned `jaxlib` and `jax` versions to 0.1.70 and 0.2.19 respectively. +• Bump pinned `jaxlib` and `jax` versions to 0.1.75 and 0.2.26 respectively. Version 0.0.1 (2021-11-24) diff --git a/data b/data index a5550f049..b32228be9 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit a5550f04949e05486f9b4ab54114ae3a7f50edb7 +Subproject commit b32228be976b0d82d1473def1ea0b15353a98c20 diff --git a/docs/source/examples.rst b/docs/source/examples.rst index fae2d5bcd..12790a093 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -22,6 +22,7 @@ Computed Tomography .. toctree:: :maxdepth: 1 + examples/ct_abel_tv_admm examples/ct_astra_pcg examples/ct_astra_tv_admm examples/ct_astra_weighted_tv_admm @@ -97,6 +98,7 @@ Total Variation .. toctree:: :maxdepth: 1 + examples/ct_abel_tv_admm examples/ct_astra_tv_admm examples/ct_astra_weighted_tv_admm examples/ct_svmbir_tv_multi @@ -135,6 +137,7 @@ ADMM .. toctree:: :maxdepth: 1 + examples/ct_abel_tv_admm examples/ct_astra_tv_admm examples/ct_astra_weighted_tv_admm examples/ct_svmbir_ppp_bm3d_admm_cg diff --git a/docs/source/references.bib b/docs/source/references.bib index 91fd7cbc3..879201bae 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -173,7 +173,7 @@ @Article {esser-2010-general Primal-Dual Algorithms for Convex Optimization in Imaging Science}, journal = {SIAM Journal on Imaging Sciences}, - doi = {10.1137/09076934x}, + doi = {10.1137/09076934x}, year = 2010, month = Jan, volume = 3, @@ -369,6 +369,15 @@ @Misc {svmbir-2020 year = 2020 } +@Misc {pyabel-2022, + author = {Stephen Gibson and Daniel Hickstein and Roman Yurchak, + Mikhail Ryazanov and Dhrubajyoti Das and Gilbert Shih}, + title = {PyAbel}, + howpublished = {PyAbel/PyAbel: v0.8.5}, + year = 2022, + doi = {10.5281/zenodo.5888391} +} + @InProceedings {venkatakrishnan-2013-plugandplay2, author = {Singanallur V. Venkatakrishnan and Charles A. Bouman and Brendt Wohlberg}, @@ -410,7 +419,7 @@ @article{yang-2012-linearized title = {Linearized augmented {L}agrangian and alternating direction methods for nuclear norm minimization}, journal = {Mathematics of Computation}, - doi = {10.1090/s0025-5718-2012-02598-1}, + doi = {10.1090/s0025-5718-2012-02598-1}, year = 2012, month = mar, volume = 82, diff --git a/docs/source/team.rst b/docs/source/team.rst index bcdea5ee0..574153a52 100644 --- a/docs/source/team.rst +++ b/docs/source/team.rst @@ -31,6 +31,7 @@ Core Developers - `Thilo Balke `_ - `Fernando Davis `_ - `Cristina Garcia Cardona `_ +- `Soumendu Majee `_ - `Michael McCann `_ - `Brendt Wohlberg `_ @@ -40,5 +41,4 @@ Contributors - `Oleg Korobkin `_ (BlockArray improvements) - `Yanpeng Yuan `_ (ASTRA interface improvements) -- `Soumendu Majee `_ (SVMBIR interface improvements) - `Saurav Maheshkar `_ (Improvements to pre-commit configuration) diff --git a/examples/scripts/ct_abel_tv_admm.py b/examples/scripts/ct_abel_tv_admm.py new file mode 100644 index 000000000..3146d4fd2 --- /dev/null +++ b/examples/scripts/ct_abel_tv_admm.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +Regularized Abel Inversion +========================== + +This example demonstrates a TV-regularized Abel inversion using +an Abel projector based on PyAbel :cite:`pyabel-2022` +""" + +import numpy as np + +import scico.numpy as snp +from scico import functional, linop, loss, plot +from scico.examples import create_circular_phantom +from scico.linop.abel import AbelProjector +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.util import device_info + +""" +Create a ground truth image. +""" +x_gt = create_circular_phantom((256, 254), [100, 50, 25], [1, 0, 0.5]) + +""" +Set up the forward operator and create a test measurement +""" +A = AbelProjector(x_gt.shape) +y = A @ x_gt +y = y + np.random.normal(size=y.shape).astype(np.float32) +ATy = A.T @ y + + +""" +Set up ADMM solver object. +""" +λ = 1.71e01 # L1 norm regularization parameter +ρ = 4.83e01 # ADMM penalty parameter +maxiter = 100 # number of ADMM iterations +cg_tol = 1e-4 # CG relative tolerance +cg_maxiter = 25 # maximum CG iterations per ADMM iteration + +g = λ * functional.L1Norm() +C = linop.FiniteDifference(input_shape=x_gt.shape) + +f = loss.SquaredL2Loss(y=y, A=A) + +x_inv = A.inverse(y) +x0 = snp.clip(x_inv, 0, 1.0) + +solver = ADMM( + f=f, + g_list=[g], + C_list=[C], + rho_list=[ρ], + x0=x0, + maxiter=maxiter, + subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), + itstat_options={"display": True, "period": 5}, +) + + +""" +Run the solver. +""" +print(f"Solving on {device_info()}\n") +solver.solve() +hist = solver.itstat_object.history(transpose=True) +x_tv = snp.clip(solver.x, 0, 1.0) + +norm = plot.matplotlib.colors.Normalize(vmin=-0.1, vmax=1.2) +fig, ax = plot.subplots(nrows=2, ncols=2, figsize=(12, 12)) +plot.imview(x_gt, title="Ground Truth", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 0], norm=norm) +plot.imview(y, title="Measurements", cmap=plot.cm.Blues, fig=fig, ax=ax[0, 1]) +plot.imview(x_inv, title="Inverse Abel", cmap=plot.cm.Blues, fig=fig, ax=ax[1, 0], norm=norm) +plot.imview( + x_tv, title="TV Regularized Inversion", cmap=plot.cm.Blues, fig=fig, ax=ax[1, 1], norm=norm +) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/ct_abel_tv_admm_tune.py b/examples/scripts/ct_abel_tv_admm_tune.py new file mode 100644 index 000000000..6e800925e --- /dev/null +++ b/examples/scripts/ct_abel_tv_admm_tune.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +Regularized Abel Inversion Tuning Demo +========================== + +This example demonstrates the use of +[scico.ray.tune](../_autosummary/scico.ray.tune.rst) to tune +parameters for the companion [example script](ct_abel_tv_admm.rst). +""" + +import numpy as np + +import jax + +import scico.ray as ray +from scico import functional, linop, loss, metric, plot +from scico.examples import create_circular_phantom +from scico.linop.abel import AbelProjector +from scico.optimize.admm import ADMM, LinearSubproblemSolver +from scico.ray import tune + +""" +Create a ground truth image. +""" +x_gt = create_circular_phantom((256, 256), [100, 50, 25], [1, 0, 0.5]) + + +""" +Set up the forward operator and create a test measurement +""" +A = AbelProjector(x_gt.shape) +y = A @ x_gt +y = y + 1 * np.random.normal(size=y.shape) +ATy = A.T @ y + +""" +Put main arrays into ray object store. +""" +ray_x_gt, ray_y = ray.put(np.array(x_gt)), ray.put(np.array(y)) + + +""" +Define performance evaluation function. +""" + + +def eval_params(config): + # Extract solver parameters from config dict. + λ, ρ = config["lambda"], config["rho"] + # Get main arrays from ray object store. + x_gt, y = ray.get([ray_x_gt, ray_y]) + # Put main arrays on jax device. + x_gt, y = jax.device_put([x_gt, y]) + # Set up problem to be solved. + A = AbelProjector(x_gt.shape) + f = loss.SquaredL2Loss(y=y, A=A) + g = λ * functional.L1Norm() + C = linop.FiniteDifference(input_shape=x_gt.shape) + # Define solver. + solver = ADMM( + f=f, + g_list=[g], + C_list=[C], + rho_list=[ρ], + x0=A.inverse(y), + maxiter=10, + subproblem_solver=LinearSubproblemSolver(), + ) + # Perform 50 iterations, reporting performance to ray.tune every 10 iterations. + for step in range(10): + x_admm = solver.solve() + tune.report(psnr=float(metric.psnr(x_gt, x_admm))) + + +""" +Define parameter search space and resources per trial. +""" +config = {"lambda": tune.loguniform(1e-2, 1e3), "rho": tune.loguniform(1e-1, 1e3)} +resources = {"gpu": 0, "cpu": 1} # gpus per trial, cpus per trial + + +""" +Run parameter search. +""" +analysis = tune.run( + eval_params, + metric="psnr", + mode="max", + num_samples=100, + config=config, + resources_per_trial=resources, + hyperopt=True, + verbose=True, +) + +""" +Display best parameters and corresponding performance. +""" +best_config = analysis.get_best_config(metric="psnr", mode="max") +print(f"Best PSNR: {analysis.get_best_trial().last_result['psnr']:.2f} dB") +print("Best config: " + ", ".join([f"{k}: {v:.2e}" for k, v in best_config.items()])) + + +""" +Plot parameter values visited during parameter search. Marker sizes are +proportional to number of iterations run at each parameter pair. The best +point in the parameter space is indicated in red. +""" +fig = plot.figure(figsize=(8, 8)) +for t in analysis.trials: + n = t.metric_analysis["training_iteration"]["max"] + plot.plot( + t.config["lambda"], + t.config["rho"], + ptyp="loglog", + lw=0, + ms=(0.5 + 1.5 * n), + marker="o", + mfc="blue", + mec="blue", + fig=fig, + ) +plot.plot( + best_config["lambda"], + best_config["rho"], + ptyp="loglog", + title="Parameter search sampling locations\n(marker size proportional to number of iterations)", + xlbl=r"$\rho$", + ylbl=r"$\lambda$", + lw=0, + ms=5.0, + marker="o", + mfc="red", + mec="red", + fig=fig, +) +ax = fig.axes[0] +ax.set_xlim([config["rho"].lower, config["rho"].upper]) +ax.set_ylim([config["lambda"].lower, config["lambda"].upper]) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index f096f4800..c20450d2a 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -9,6 +9,7 @@ Organized by Application Computed Tomography ^^^^^^^^^^^^^^^^^^^ + - ct_abel_tv_admm.py - ct_astra_pcg.py - ct_astra_tv_admm.py - ct_astra_weighted_tv_admm.py @@ -66,6 +67,7 @@ Plug and Play Priors Total Variation ^^^^^^^^^^^^^^^ + - ct_abel_tv_admm.py - ct_astra_tv_admm.py - ct_astra_weighted_tv_admm.py - ct_svmbir_tv_multi.py diff --git a/requirements.txt b/requirements.txt index cbe35a3fb..c987e034c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ jax==0.2.26 flax bm3d svmbir>=0.2.7 +pyabel>=0.8.5 diff --git a/scico/examples.py b/scico/examples.py index eb40db9c1..cb661803d 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -12,6 +12,7 @@ import os import tempfile import zipfile +from typing import Optional import numpy as np @@ -19,7 +20,7 @@ import scico.numpy as snp from scico import util -from scico.typing import JaxArray +from scico.typing import JaxArray, Shape from scipy.ndimage import zoom @@ -239,3 +240,51 @@ def tile_volume_slices(x: JaxArray, sep_width: int = 10) -> JaxArray: out = snp.where(snp.isnan(out), snp.nanmax(out), out) return out + + +def create_cone(img_shape: Shape, center: Optional[list] = None): + """Compute a 2D map of the distance from a center pixel. + + Args: + img_shape : Shape of the image for which the distance map is being computed. + center : Tuple of center pixel ids. If None, this is set to the center of the image + + Returns: + An image containing a 2D map of the distances + """ + + if center == None: + center = [img_dim // 2 for img_dim in img_shape] + + coords = [snp.arange(0, img_dim) for img_dim in img_shape] + coord_mesh = snp.meshgrid(*coords, sparse=True, indexing="ij") + + dist_map = sum([(coord_mesh[i] - center[i]) ** 2 for i in range(len(coord_mesh))]) + dist_map = snp.sqrt(dist_map) + + return dist_map + + +def create_circular_phantom( + img_shape: Shape, radius_list: list, val_list: list, center: Optional[list] = None +): + """Construct a circular phantom with given radii and intensities + + Args: + img_shape : Shape of the phontom to be created + radius_list : List of radii of the rings in the phantom + val_list : List of intensity values of the rings in the phantom + center : Tuple of center pixel ids. If None, this is set to the center of the image + + Returns: + The computed circular phantom + """ + + dist_map = create_cone(img_shape, center) + + img = snp.zeros(img_shape) + for r, val in zip(radius_list, val_list): + # img[dist_map < r] = val + img = img.at[dist_map < r].set(val) + + return img diff --git a/scico/linop/abel.py b/scico/linop/abel.py new file mode 100644 index 000000000..cd0119f27 --- /dev/null +++ b/scico/linop/abel.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2022 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Abel transform LinearOperator wrapping the pyabel package. + +Abel transform LinearOperator wrapping the +`pyabel `_ package. +""" + +import math +from typing import Optional + +import numpy as np + +import jax +import jax.numpy as jnp +import jax.numpy.fft as jnfft + +import abel + +from scico.linop import LinearOperator +from scico.typing import JaxArray, Shape +from scipy.linalg import solve_triangular + + +class AbelProjector(LinearOperator): + r"""Abel transform projector based on `PyAbel `_. + + Perform Abel transform (parallel beam tomographic projection of + cylindrically symmetric objects) for a 2D image. The input 2D image + is assumed to be centered and left-right symmetric. + """ + + def __init__(self, img_shape: Shape): + """ + Args: + img_shape: Shape of the input image. + """ + self.proj_mat_quad = _pyabel_daun_get_proj_matrix(img_shape) + + super().__init__( + input_shape=img_shape, + output_shape=img_shape, + input_dtype=np.float32, + output_dtype=np.float32, + adj_fn=self._adj, + jit=True, + ) + + def _eval(self, x: JaxArray) -> JaxArray: + return _pyabel_transform(x, direction="forward", proj_mat_quad=self.proj_mat_quad).astype( + self.output_dtype + ) + + def _adj(self, x: JaxArray) -> JaxArray: + return _pyabel_transform(x, direction="transpose", proj_mat_quad=self.proj_mat_quad).astype( + self.input_dtype + ) + + def inverse(self, y: JaxArray) -> JaxArray: + """Perform inverse Abel transform. + + Args: + y: Input image (assumed to be a result of an Abel transform) + + Returns: + Output of inverse Abel transform + """ + return _pyabel_transform(y, direction="inverse", proj_mat_quad=self.proj_mat_quad).astype( + self.input_dtype + ) + + +def _pyabel_transform( + x: JaxArray, direction: str, proj_mat_quad: JaxArray, symmetry_axis: Optional[list] = None +) -> JaxArray: + """Perform Abel transformations (forward, inverse and transposed). + + This function contains code copied from `PyAbel `_. + """ + + if symmetry_axis is None: + symmetry_axis = [None] + + Q0, Q1, Q2, Q3 = get_image_quadrants( + x, symmetry_axis=symmetry_axis, use_quadrants=(True, True, True, True) + ) + + def transform_quad(data): + if direction == "forward": + return data.dot(proj_mat_quad) + elif direction == "transpose": + return data.dot(proj_mat_quad.T) + elif direction == "inverse": + return solve_triangular(proj_mat_quad.T, data.T).T + else: + ValueError("Unsupported direction") + + AQ0 = AQ1 = AQ2 = AQ3 = None + AQ1 = transform_quad(Q1) + + if 1 not in symmetry_axis: + AQ2 = transform_quad(Q2) + + if 0 not in symmetry_axis: + AQ0 = transform_quad(Q0) + + if None in symmetry_axis: + AQ3 = transform_quad(Q3) + + return put_image_quadrants( + (AQ0, AQ1, AQ2, AQ3), original_image_shape=x.shape, symmetry_axis=symmetry_axis + ) + + +def _pyabel_daun_get_proj_matrix(img_shape: Shape) -> JaxArray: + """Get single-quadrant projection matrix.""" + proj_matrix = abel.daun.get_bs_cached( + math.ceil(img_shape[1] / 2), + degree=0, + reg_type=None, + strength=0, + direction="forward", + verbose=False, + ) + return jax.device_put(proj_matrix) + + +# Read abel.tools.symmetry module into a string. +mod_file = abel.tools.symmetry.__file__ +with open(mod_file, "r") as f: + mod_str = f.read() + +# Replace numpy functions that touch the main arrays with corresponding jax.numpy functions +mod_str = mod_str.replace("fftpack.", "jnfft.") +mod_str = mod_str.replace("np.atleast_2d", "jnp.atleast_2d") +mod_str = mod_str.replace("np.flip", "jnp.flip") +mod_str = mod_str.replace("np.concat", "jnp.concat") + +# Exec the module extract defined functions from the exec scope +scope = {"jnp": jnp, "jnfft": jnfft} +exec(mod_str, scope) +get_image_quadrants = scope["get_image_quadrants"] +put_image_quadrants = scope["put_image_quadrants"] diff --git a/scico/test/linop/test_abel.py b/scico/test/linop/test_abel.py new file mode 100644 index 000000000..7a2a57d45 --- /dev/null +++ b/scico/test/linop/test_abel.py @@ -0,0 +1,65 @@ +import numpy as np + +import jax + +import pytest + +import scico.numpy as snp +from scico.linop.abel import AbelProjector +from scico.test.linop.test_linop import adjoint_test + +BIG_INPUT = (128, 128) +SMALL_INPUT = (4, 5) + + +def make_im(Nx, Ny): + x, y = snp.meshgrid(snp.linspace(-1, 1, Nx), snp.linspace(-1, 1, Ny)) + + im = snp.where(x ** 2 + y ** 2 < 0.3, 1.0, 0.0) + + return im + + +@pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) +def test_inverse(Nx, Ny): + im = make_im(Nx, Ny) + A = AbelProjector(im.shape) + + Ax = A @ im + im_hat = A.inverse(Ax) + np.testing.assert_allclose(im_hat, im, rtol=5e-5) + + +@pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) +def test_adjoint(Nx, Ny): + im = make_im(Nx, Ny) + A = AbelProjector(im.shape) + adjoint_test(A) + + +@pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) +def test_ATA(Nx, Ny): + x = make_im(Nx, Ny) + A = AbelProjector(x.shape) + Ax = A(x) + ATAx = A.adj(Ax) + np.testing.assert_allclose(np.sum(x * ATAx), np.linalg.norm(Ax) ** 2, rtol=5e-5) + + +@pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) +def test_grad(Nx, Ny): + # ensure that we can take grad on a function using our projector + # grad || A(x) ||_2^2 == 2 A.T @ A x + x = make_im(Nx, Ny) + A = AbelProjector(x.shape) + g = lambda x: jax.numpy.linalg.norm(A(x)) ** 2 + np.testing.assert_allclose(jax.grad(g)(x), 2 * A.adj(A(x)), rtol=5e-5) + + +@pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT)) +def test_adjoint_grad(Nx, Ny): + x = make_im(Nx, Ny) + A = AbelProjector(x.shape) + Ax = A @ x + f = lambda y: jax.numpy.linalg.norm(A.T(y)) ** 2 + np.testing.assert_allclose(jax.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=5e-5) diff --git a/scico/test/test_examples.py b/scico/test/test_examples.py index 9730c6e77..d3bbb6251 100644 --- a/scico/test/test_examples.py +++ b/scico/test/test_examples.py @@ -4,8 +4,11 @@ import numpy as np import imageio +import pytest from scico.examples import ( + create_circular_phantom, + create_cone, downsample_volume, epfl_deconv_data, rgb2gray, @@ -60,3 +63,28 @@ def test_tile_volume_slices(): v = np.ones((16, 16, 16, 3)) tvs = tile_volume_slices(v) assert tvs.ndim == 3 and tvs.shape[-1] == 3 + + +def test_create_circular_phantom(): + img_shape = (32, 32) + radius_list = [2, 4, 8] + val_list = [2, 4, 8] + x_gt = create_circular_phantom(img_shape, radius_list, val_list) + + assert x_gt.shape == img_shape + assert np.max(x_gt) == max(val_list) + assert np.min(x_gt) == 0 + + +@pytest.mark.parametrize( + "img_shape", + ( + (3, 3), + (40, 40), + (100, 100), + (3, 3, 3), + ), +) +def test_create_cone(img_shape): + x_gt = create_cone(img_shape) + assert x_gt.shape == img_shape