diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index a2dcbbc07..0b05d7edf 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -11,7 +11,7 @@ # isort: off from scico._generic_operators import LinearOperator -from ._linop import Diagonal, Identity, power_iteration, Sum +from ._linop import Diagonal, Identity, power_iteration, Sum, valid_adjoint from ._matrix import MatrixOperator from ._diff import FiniteDifference from ._convolve import Convolve, ConvolveByX @@ -32,6 +32,7 @@ "LinearOperatorStack", "Sum", "power_iteration", + "valid_adjoint", ] # Imported items in __all__ appear to originate in top-level linop module diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index 217c52a24..297a88f66 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -22,17 +22,24 @@ from scico.random import randn from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape -__author__ = """Luke Pfister """ +__author__ = """\n""".join( + ["Luke Pfister ", "Brendt Wohlberg "] +) def power_iteration(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] = None): - """Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator` using power iteration. + """Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator`. + + Compute largest eigenvalue of a diagonalizable + :class:`.LinearOperator` using power iteration. Args: - A: :class:`.LinearOperator` used for computation. Must be diagonalizable. - For arbitrary :class:`.LinearOperator`, call this function on ``A.conj().T @ A``. + A: :class:`.LinearOperator` used for computation. Must be + diagonalizable. For arbitrary :class:`.LinearOperator`, call + this function on ``A.conj().T @ A``. maxiter: Maximum number of power iterations to use. Default: 100 - key: Jax PRNG key. Defaults to None, in which case a new key is created. + key: Jax PRNG key. Defaults to None, in which case a new key is + created. Returns: tuple: A tuple (mu, v) containing: @@ -51,8 +58,62 @@ def power_iteration(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey return mu, v +def valid_adjoint( + A: LinearOperator, + AT: LinearOperator, + eps: Optional[float] = 1e-7, + key: Optional[PRNGKey] = None, +) -> Union[bool, float]: + r"""Check whether :class:`.LinearOperator` `AT` is the adjoint of `A`. + + The test exploits the identity + + .. math:: + \mathbf{y}^T (A \mathbf{x}) = (\mathbf{y}^T A) \mathbf{x} = + (A^T \mathbf{y})^T \mathbf{x} + + by computing :math:`\mathbf{u} = A \mathbf{x}` and + :math:`\mathbf{v} = A^T \mathbf{y}` for random :math:`\mathbf{x}` + and :math:`\mathbf{y}` and confirming that :math:`\| \mathbf{y}^T + \mathbf{u} - \mathbf{v}^T \mathbf{x} \|_2 < \epsilon` since + + .. math:: + \mathbf{y}^T \mathbf{u} = \mathbf{y}^T (A \mathbf{x}) = + (A^T \mathbf{y})^T \mathbf{x} = \mathbf{v}^T \mathbf{x} + + when :math:`A^T` is a valid adjoint of :math:`A`. If :math:`A` is a + complex operator (with a complex `input_dtype`) then the test checks + whether `AT` is the Hermitian conjugate of `A`, with a test as above, + but with all the :math:`\cdot^T` replaced with :math:`\cdot^H`. + + Args: + A: Primary :class:`.LinearOperator`. + AT: Adjoint :class:`.LinearOperator`. + eps: Error threshold for validation of `AT` as adjoint of `A`. If + None, the relative error is returned instead of a boolean value. + key: Jax PRNG key. Defaults to None, in which case a new key is + created. + + Returns: + Boolean value indicating that validation passed, or relative error + of test, depending on type of parameter `eps`. + """ + + x0, key = randn(shape=A.input_shape, key=key, dtype=A.input_dtype) + x1, key = randn(shape=AT.input_shape, key=key, dtype=AT.input_dtype) + y0 = A(x0) + y1 = AT(x1) + x1y0 = snp.dot(x1.ravel().conj(), y0.ravel()) + y1x0 = snp.dot(y1.ravel().conj(), x0.ravel()) + err = snp.linalg.norm(x1y0 - y1x0) / max(snp.linalg.norm(x1y0), snp.linalg.norm(y1x0)) + if eps is None: + return err + else: + return err < eps + + class Diagonal(LinearOperator): - """Diagonal linear operator""" + """Diagonal linear operator.""" def __init__(self, diagonal: JaxArray, input_dtype: Optional[DType] = None, **kwargs): r""" @@ -124,7 +185,7 @@ def __rmatmul__(self, x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockAr class Sum(LinearOperator): - """A linear operator for summing along an axis or set of axes""" + """A linear operator for summing along an axis or set of axes.""" def __init__( self, @@ -138,12 +199,15 @@ def __init__( Wraps :func:`jax.numpy.sum` as a :class:`.LinearOperator`. Args: - sum_axis: The axis or set of axes to sum over. If `None`, sum is taken over all axes. + sum_axis: The axis or set of axes to sum over. If `None`, + sum is taken over all axes. input_shape: Shape of input array. input_dtype: `dtype` for input argument. - Defaults to `float32`. If this LinearOperator implements complex-valued operations, - this must be `complex64` for proper adjoint and gradient calculation. - jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. + Defaults to `float32`. If this LinearOperator implements + complex-valued operations, this must be `complex64` for + proper adjoint and gradient calculation. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. """ input_ndim = len(input_shape) diff --git a/scico/test/linop/test_circconv.py b/scico/test/linop/test_circconv.py index 9c591fd2f..9f54dbba9 100644 --- a/scico/test/linop/test_circconv.py +++ b/scico/test/linop/test_circconv.py @@ -9,7 +9,7 @@ import scico.numpy as snp from scico.linop import CircularConvolve, Convolve from scico.random import randint, randn, uniform -from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test +from scico.test.linop.test_linop import adjoint_test SHAPE_SPECS = [ ((16,), None, (3,)), # 1D @@ -43,7 +43,8 @@ def test_eval(self, axes_shape_spec, input_dtype, jit): Ax = A @ x - # check that a specific pixel of Ax computes an inner product between x and (flipped, padded, shifted) h + # check that a specific pixel of Ax computes an inner product between x and + # (flipped, padded, shifted) h h_flipped = np.flip(h, range(-A.ndims, 0)) # flip only in the spatial dims (not batches) x_inds = (...,) + tuple( @@ -68,8 +69,7 @@ def test_adjoint(self, axes_shape_spec, input_dtype, jit): A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit) - adjoint_AtA_test(A, self.key) - adjoint_AAt_test(A, self.key) + adjoint_test(A, self.key) @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) diff --git a/scico/test/linop/test_convolve.py b/scico/test/linop/test_convolve.py index 9bbd8b200..2eaa0e171 100644 --- a/scico/test/linop/test_convolve.py +++ b/scico/test/linop/test_convolve.py @@ -10,7 +10,7 @@ from scico.linop import Convolve, ConvolveByX, LinearOperator from scico.random import randn -from scico.test.linop.test_linop import AbsMatOp, adjoint_AAt_test, adjoint_AtA_test +from scico.test.linop.test_linop import AbsMatOp, adjoint_test class TestConvolve: @@ -46,8 +46,7 @@ def test_adjoint(self, input_shape, mode, jit, input_dtype): A = Convolve(h=psf, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit) - adjoint_AtA_test(A, self.key) - adjoint_AAt_test(A, self.key) + adjoint_test(A, self.key) class ConvolveTestObj: @@ -212,8 +211,7 @@ def test_adjoint(self, input_shape, mode, jit, input_dtype): A = ConvolveByX(x=x, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit) - adjoint_AtA_test(A, self.key) - adjoint_AAt_test(A, self.key) + adjoint_test(A, self.key) class ConvolveByXTestObj: diff --git a/scico/test/linop/test_dft.py b/scico/test/linop/test_dft.py index 728bcf26a..e59a58286 100644 --- a/scico/test/linop/test_dft.py +++ b/scico/test/linop/test_dft.py @@ -6,7 +6,7 @@ from scico.linop import DFT from scico.random import randn -from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test +from scico.test.linop.test_linop import adjoint_test class TestDFT: @@ -47,8 +47,7 @@ def test_adjoint(self, input_shape, pad_output, jit): output_shape = None F = DFT(input_shape=input_shape, output_shape=output_shape, jit=jit) - adjoint_AtA_test(F, self.key) - adjoint_AAt_test(F, self.key) + adjoint_test(F, self.key) @pytest.mark.parametrize("input_shape", [(32,), (32, 48)]) @pytest.mark.parametrize("pad_output", [True, False]) diff --git a/scico/test/linop/test_diff.py b/scico/test/linop/test_diff.py index c05fba36d..168b5ded8 100644 --- a/scico/test/linop/test_diff.py +++ b/scico/test/linop/test_diff.py @@ -6,7 +6,7 @@ from scico.blockarray import BlockArray from scico.linop import FiniteDifference from scico.random import randn -from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test +from scico.test.linop.test_linop import adjoint_test @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @@ -58,8 +58,7 @@ def test_adjoint(self, input_shape, input_dtype, axes, jit): A = FiniteDifference( input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit ) - adjoint_AtA_test(A) - adjoint_AAt_test(A) + adjoint_test(A) @pytest.mark.parametrize( diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index 6adb53592..f0b3d53a9 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -18,11 +18,8 @@ from scico.typing import PRNGKey -def adjoint_AtA_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 1e-4): - """Check the validity of A.conj().T as the adjoint for a LinearOperator A - - Compares the quantity sum(x.conj() * A.conj().T @ A @ x) against - norm(A @ x)**2. If the adjoint is correct, these quantities should be equal. +def adjoint_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 1e-4): + """Check the validity of A.conj().T as the adjoint for a LinearOperator A. Args: A : LinearOperator to test @@ -30,65 +27,15 @@ def adjoint_AtA_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rto rtol: Relative tolerance """ - # Generate a signal in the domain of A - x, key = randn(A.input_shape, dtype=A.input_dtype, key=key) - - Ax = A @ x - - AtAx = A.conj().T @ Ax - num = snp.sum(x.conj() * AtAx) - den = snp.linalg.norm(Ax) ** 2 - np.testing.assert_allclose(num / den, 1, rtol=rtol) - - AtAx = A.H @ Ax - num = snp.sum(x.conj() * AtAx) - den = snp.linalg.norm(Ax) ** 2 - np.testing.assert_allclose(num / den, 1, rtol=rtol) - - AtAx = A.adj(Ax) - num = snp.sum(x.conj() * AtAx) - den = snp.linalg.norm(Ax) ** 2 - np.testing.assert_allclose(num / den, 1, rtol=rtol) - - -def adjoint_AAt_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 1e-4): - """Check the validity of A as the adjoint for a LinearOperator A.conj().T - - Compares the quantity sum(y.conj() * A @ A.conj().T @ y) against - norm(A.conj().T @ y)**2. If the adjoint is correct, these quantities should be equal. - - Args: - A : LinearOperator to test - key: PRNGKey for generating `x`. - rtol: Relative tolerance - """ - # Generate a signal in the domain of A^T - y, key = randn(A.output_shape, dtype=A.output_dtype, key=key) - - Aty = A.conj().T @ y - AAty = A @ Aty - num = snp.sum(y.conj() * AAty) - den = snp.linalg.norm(Aty) ** 2 - np.testing.assert_allclose(num / den, 1, rtol=rtol) - - Aty = A.H @ y - AAty = A @ Aty - num = snp.sum(y.conj() * AAty) - den = snp.linalg.norm(Aty) ** 2 - np.testing.assert_allclose(num / den, 1, rtol=rtol) - - Aty = A.adj(y) - AAty = A @ Aty - num = snp.sum(y.conj() * AAty) - den = snp.linalg.norm(Aty) ** 2 - np.testing.assert_allclose(num / den, 1, rtol=rtol) + assert linop.valid_adjoint(A, A.H, rtol, key) class AbsMatOp(linop.LinearOperator): """Simple LinearOperator subclass for testing purposes. - Similar to linop.MatrixOperator, but does not use the specialized MatrixOperator methods (.T, adj, etc). - Used to verify the LinearOperator interface. + Similar to linop.MatrixOperator, but does not use the specialized + MatrixOperator methods (.T, adj, etc). Used to verify the + LinearOperator interface. """ def __init__(self, A, adj_fn=None): @@ -369,8 +316,7 @@ def test_adjoint(self, input_shape, diagonal_dtype): diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key) D = linop.Diagonal(diagonal=diagonal) - adjoint_AtA_test(D) - adjoint_AAt_test(D) + adjoint_test(D) @pytest.mark.parametrize("operator", [op.add, op.sub]) @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) @@ -544,8 +490,7 @@ def test_sum_eval(sumtestobj, axis): def test_sum_adj(sumtestobj, axis): x = sumtestobj.x A = linop.Sum(input_shape=x.shape, input_dtype=x.dtype, sum_axis=axis) - adjoint_AtA_test(A) - adjoint_AAt_test(A) + adjoint_test(A) @pytest.mark.parametrize("axis", (5, (1, 1), (0, 1, 2, 3, 4))) diff --git a/scico/test/linop/test_optics.py b/scico/test/linop/test_optics.py index 4009dbdbc..dbd7293e4 100644 --- a/scico/test/linop/test_optics.py +++ b/scico/test/linop/test_optics.py @@ -11,7 +11,7 @@ radial_transverse_frequency, ) from scico.random import randn -from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test +from scico.test.linop.test_linop import adjoint_test prop_list = [AngularSpectrumPropagator, FresnelPropagator, FraunhoferPropagator] @@ -29,8 +29,7 @@ def setup_method(self, method): @pytest.mark.parametrize("prop", prop_list) def test_prop_adjoint(self, prop, ndim): A = prop(input_shape=(self.N,) * ndim, dx=self.dx, k0=self.k0, z=self.z) - adjoint_AtA_test(A, self.key) - adjoint_AAt_test(A, self.key) + adjoint_test(A, self.key) @pytest.mark.parametrize("ndim", [1, 2]) def test_AS_inverse(self, ndim): diff --git a/scico/test/linop/test_radon_astra.py b/scico/test/linop/test_radon_astra.py index eac69d670..efebf5c45 100644 --- a/scico/test/linop/test_radon_astra.py +++ b/scico/test/linop/test_radon_astra.py @@ -4,7 +4,7 @@ import pytest -from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test +from scico.test.linop.test_linop import adjoint_test try: from scico.linop.radon_astra import ParallelBeamProjector @@ -95,5 +95,4 @@ def test_adjoint_grad(testobj): def test_adjoint(testobj): A = testobj.A - adjoint_AAt_test(A, rtol=get_tol()) - adjoint_AtA_test(A, rtol=get_tol()) + adjoint_test(A, rtol=get_tol()) diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/test_radon_svmbir.py index 28f68bea2..0e8812f24 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/test_radon_svmbir.py @@ -6,7 +6,7 @@ import scico import scico.numpy as snp -from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test +from scico.test.linop.test_linop import adjoint_test from scico.test.test_functional import prox_test try: @@ -61,8 +61,7 @@ def test_adjoint(Nx, Ny, num_angles, num_channels, is_3d): im = make_im(Nx, Ny, is_3d) A = make_A(im, num_angles, num_channels) - adjoint_AtA_test(A) - adjoint_AAt_test(A) + adjoint_test(A) @pytest.mark.parametrize("Nx, Ny, num_angles, num_channels", (SMALL_INPUT,)) diff --git a/scico/test/linop/test_stack.py b/scico/test/linop/test_stack.py index 1d2fb0b9d..92cb22579 100644 --- a/scico/test/linop/test_stack.py +++ b/scico/test/linop/test_stack.py @@ -5,7 +5,7 @@ import pytest from scico.linop import Convolve, Identity, LinearOperatorStack -from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test +from scico.test.linop.test_linop import adjoint_test class TestLinearOperatorStack: @@ -64,15 +64,13 @@ def test_adjoint(self, collapse, jit): A = Convolve(jax.device_put(np.ones((3, 3))), (9, 15)) B = Convolve(jax.device_put(np.ones((2, 2))), (9, 15)) H = LinearOperatorStack([A, B], collapse=collapse, jit=jit) - adjoint_AtA_test(H, self.key) - adjoint_AAt_test(H, self.key) + adjoint_test(H, self.key) # collapsable case A = Convolve(jax.device_put(np.ones((2, 2))), (9, 15)) B = Convolve(jax.device_put(np.ones((2, 2))), (9, 15)) H = LinearOperatorStack([A, B], collapse=collapse, jit=jit) - adjoint_AtA_test(H, self.key) - adjoint_AAt_test(H, self.key) + adjoint_test(H, self.key) @pytest.mark.parametrize("collapse", [False, True]) @pytest.mark.parametrize("jit", [False, True]) diff --git a/scico/util.py b/scico/util.py index cff495028..645d735b0 100644 --- a/scico/util.py +++ b/scico/util.py @@ -44,7 +44,10 @@ def device_info(id: int = 0) -> str: # pragma: no cover """Get a string describing the specified device. Args: - id: ID number of device + id: ID number of device. + + Returns: + Device description string. """ numdev = jax.device_count() if id >= numdev: