Skip to content

Commit

Permalink
Improve LinearOperator adjoint tests (#72)
Browse files Browse the repository at this point in the history
* Docstring cleanup

* Docstring fix

* Add adjoint validation function

* Add adjoint validation function

* Fix for complex values

* Update docstring

* Replace adjoint tests

* Switch to new adjoint test function in other linop module tests
  • Loading branch information
bwohlberg authored Nov 1, 2021
1 parent d86b51f commit 656f59d
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 105 deletions.
3 changes: 2 additions & 1 deletion scico/linop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# isort: off
from scico._generic_operators import LinearOperator
from ._linop import Diagonal, Identity, power_iteration, Sum
from ._linop import Diagonal, Identity, power_iteration, Sum, valid_adjoint
from ._matrix import MatrixOperator
from ._diff import FiniteDifference
from ._convolve import Convolve, ConvolveByX
Expand All @@ -32,6 +32,7 @@
"LinearOperatorStack",
"Sum",
"power_iteration",
"valid_adjoint",
]

# Imported items in __all__ appear to originate in top-level linop module
Expand Down
86 changes: 75 additions & 11 deletions scico/linop/_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,24 @@
from scico.random import randn
from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape

__author__ = """Luke Pfister <[email protected]>"""
__author__ = """\n""".join(
["Luke Pfister <[email protected]>", "Brendt Wohlberg <[email protected]>"]
)


def power_iteration(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] = None):
"""Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator` using power iteration.
"""Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator`.
Compute largest eigenvalue of a diagonalizable
:class:`.LinearOperator` using power iteration.
Args:
A: :class:`.LinearOperator` used for computation. Must be diagonalizable.
For arbitrary :class:`.LinearOperator`, call this function on ``A.conj().T @ A``.
A: :class:`.LinearOperator` used for computation. Must be
diagonalizable. For arbitrary :class:`.LinearOperator`, call
this function on ``A.conj().T @ A``.
maxiter: Maximum number of power iterations to use. Default: 100
key: Jax PRNG key. Defaults to None, in which case a new key is created.
key: Jax PRNG key. Defaults to None, in which case a new key is
created.
Returns:
tuple: A tuple (mu, v) containing:
Expand All @@ -51,8 +58,62 @@ def power_iteration(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey
return mu, v


def valid_adjoint(
A: LinearOperator,
AT: LinearOperator,
eps: Optional[float] = 1e-7,
key: Optional[PRNGKey] = None,
) -> Union[bool, float]:
r"""Check whether :class:`.LinearOperator` `AT` is the adjoint of `A`.
The test exploits the identity
.. math::
\mathbf{y}^T (A \mathbf{x}) = (\mathbf{y}^T A) \mathbf{x} =
(A^T \mathbf{y})^T \mathbf{x}
by computing :math:`\mathbf{u} = A \mathbf{x}` and
:math:`\mathbf{v} = A^T \mathbf{y}` for random :math:`\mathbf{x}`
and :math:`\mathbf{y}` and confirming that :math:`\| \mathbf{y}^T
\mathbf{u} - \mathbf{v}^T \mathbf{x} \|_2 < \epsilon` since
.. math::
\mathbf{y}^T \mathbf{u} = \mathbf{y}^T (A \mathbf{x}) =
(A^T \mathbf{y})^T \mathbf{x} = \mathbf{v}^T \mathbf{x}
when :math:`A^T` is a valid adjoint of :math:`A`. If :math:`A` is a
complex operator (with a complex `input_dtype`) then the test checks
whether `AT` is the Hermitian conjugate of `A`, with a test as above,
but with all the :math:`\cdot^T` replaced with :math:`\cdot^H`.
Args:
A: Primary :class:`.LinearOperator`.
AT: Adjoint :class:`.LinearOperator`.
eps: Error threshold for validation of `AT` as adjoint of `A`. If
None, the relative error is returned instead of a boolean value.
key: Jax PRNG key. Defaults to None, in which case a new key is
created.
Returns:
Boolean value indicating that validation passed, or relative error
of test, depending on type of parameter `eps`.
"""

x0, key = randn(shape=A.input_shape, key=key, dtype=A.input_dtype)
x1, key = randn(shape=AT.input_shape, key=key, dtype=AT.input_dtype)
y0 = A(x0)
y1 = AT(x1)
x1y0 = snp.dot(x1.ravel().conj(), y0.ravel())
y1x0 = snp.dot(y1.ravel().conj(), x0.ravel())
err = snp.linalg.norm(x1y0 - y1x0) / max(snp.linalg.norm(x1y0), snp.linalg.norm(y1x0))
if eps is None:
return err
else:
return err < eps


class Diagonal(LinearOperator):
"""Diagonal linear operator"""
"""Diagonal linear operator."""

def __init__(self, diagonal: JaxArray, input_dtype: Optional[DType] = None, **kwargs):
r"""
Expand Down Expand Up @@ -124,7 +185,7 @@ def __rmatmul__(self, x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockAr


class Sum(LinearOperator):
"""A linear operator for summing along an axis or set of axes"""
"""A linear operator for summing along an axis or set of axes."""

def __init__(
self,
Expand All @@ -138,12 +199,15 @@ def __init__(
Wraps :func:`jax.numpy.sum` as a :class:`.LinearOperator`.
Args:
sum_axis: The axis or set of axes to sum over. If `None`, sum is taken over all axes.
sum_axis: The axis or set of axes to sum over. If `None`,
sum is taken over all axes.
input_shape: Shape of input array.
input_dtype: `dtype` for input argument.
Defaults to `float32`. If this LinearOperator implements complex-valued operations,
this must be `complex64` for proper adjoint and gradient calculation.
jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator.
Defaults to `float32`. If this LinearOperator implements
complex-valued operations, this must be `complex64` for
proper adjoint and gradient calculation.
jit: If ``True``, jit the evaluation, adjoint, and gram
functions of the LinearOperator.
"""

input_ndim = len(input_shape)
Expand Down
8 changes: 4 additions & 4 deletions scico/test/linop/test_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import scico.numpy as snp
from scico.linop import CircularConvolve, Convolve
from scico.random import randint, randn, uniform
from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test
from scico.test.linop.test_linop import adjoint_test

SHAPE_SPECS = [
((16,), None, (3,)), # 1D
Expand Down Expand Up @@ -43,7 +43,8 @@ def test_eval(self, axes_shape_spec, input_dtype, jit):

Ax = A @ x

# check that a specific pixel of Ax computes an inner product between x and (flipped, padded, shifted) h
# check that a specific pixel of Ax computes an inner product between x and
# (flipped, padded, shifted) h
h_flipped = np.flip(h, range(-A.ndims, 0)) # flip only in the spatial dims (not batches)

x_inds = (...,) + tuple(
Expand All @@ -68,8 +69,7 @@ def test_adjoint(self, axes_shape_spec, input_dtype, jit):

A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit)

adjoint_AtA_test(A, self.key)
adjoint_AAt_test(A, self.key)
adjoint_test(A, self.key)

@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS)
Expand Down
8 changes: 3 additions & 5 deletions scico/test/linop/test_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from scico.linop import Convolve, ConvolveByX, LinearOperator
from scico.random import randn
from scico.test.linop.test_linop import AbsMatOp, adjoint_AAt_test, adjoint_AtA_test
from scico.test.linop.test_linop import AbsMatOp, adjoint_test


class TestConvolve:
Expand Down Expand Up @@ -46,8 +46,7 @@ def test_adjoint(self, input_shape, mode, jit, input_dtype):

A = Convolve(h=psf, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit)

adjoint_AtA_test(A, self.key)
adjoint_AAt_test(A, self.key)
adjoint_test(A, self.key)


class ConvolveTestObj:
Expand Down Expand Up @@ -212,8 +211,7 @@ def test_adjoint(self, input_shape, mode, jit, input_dtype):

A = ConvolveByX(x=x, input_shape=input_shape, input_dtype=input_dtype, mode=mode, jit=jit)

adjoint_AtA_test(A, self.key)
adjoint_AAt_test(A, self.key)
adjoint_test(A, self.key)


class ConvolveByXTestObj:
Expand Down
5 changes: 2 additions & 3 deletions scico/test/linop/test_dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from scico.linop import DFT
from scico.random import randn
from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test
from scico.test.linop.test_linop import adjoint_test


class TestDFT:
Expand Down Expand Up @@ -47,8 +47,7 @@ def test_adjoint(self, input_shape, pad_output, jit):
output_shape = None

F = DFT(input_shape=input_shape, output_shape=output_shape, jit=jit)
adjoint_AtA_test(F, self.key)
adjoint_AAt_test(F, self.key)
adjoint_test(F, self.key)

@pytest.mark.parametrize("input_shape", [(32,), (32, 48)])
@pytest.mark.parametrize("pad_output", [True, False])
Expand Down
5 changes: 2 additions & 3 deletions scico/test/linop/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scico.blockarray import BlockArray
from scico.linop import FiniteDifference
from scico.random import randn
from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test
from scico.test.linop.test_linop import adjoint_test


@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
Expand Down Expand Up @@ -58,8 +58,7 @@ def test_adjoint(self, input_shape, input_dtype, axes, jit):
A = FiniteDifference(
input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit
)
adjoint_AtA_test(A)
adjoint_AAt_test(A)
adjoint_test(A)


@pytest.mark.parametrize(
Expand Down
71 changes: 8 additions & 63 deletions scico/test/linop/test_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,77 +18,24 @@
from scico.typing import PRNGKey


def adjoint_AtA_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 1e-4):
"""Check the validity of A.conj().T as the adjoint for a LinearOperator A
Compares the quantity sum(x.conj() * A.conj().T @ A @ x) against
norm(A @ x)**2. If the adjoint is correct, these quantities should be equal.
def adjoint_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 1e-4):
"""Check the validity of A.conj().T as the adjoint for a LinearOperator A.
Args:
A : LinearOperator to test
key: PRNGKey for generating `x`.
rtol: Relative tolerance
"""

# Generate a signal in the domain of A
x, key = randn(A.input_shape, dtype=A.input_dtype, key=key)

Ax = A @ x

AtAx = A.conj().T @ Ax
num = snp.sum(x.conj() * AtAx)
den = snp.linalg.norm(Ax) ** 2
np.testing.assert_allclose(num / den, 1, rtol=rtol)

AtAx = A.H @ Ax
num = snp.sum(x.conj() * AtAx)
den = snp.linalg.norm(Ax) ** 2
np.testing.assert_allclose(num / den, 1, rtol=rtol)

AtAx = A.adj(Ax)
num = snp.sum(x.conj() * AtAx)
den = snp.linalg.norm(Ax) ** 2
np.testing.assert_allclose(num / den, 1, rtol=rtol)


def adjoint_AAt_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 1e-4):
"""Check the validity of A as the adjoint for a LinearOperator A.conj().T
Compares the quantity sum(y.conj() * A @ A.conj().T @ y) against
norm(A.conj().T @ y)**2. If the adjoint is correct, these quantities should be equal.
Args:
A : LinearOperator to test
key: PRNGKey for generating `x`.
rtol: Relative tolerance
"""
# Generate a signal in the domain of A^T
y, key = randn(A.output_shape, dtype=A.output_dtype, key=key)

Aty = A.conj().T @ y
AAty = A @ Aty
num = snp.sum(y.conj() * AAty)
den = snp.linalg.norm(Aty) ** 2
np.testing.assert_allclose(num / den, 1, rtol=rtol)

Aty = A.H @ y
AAty = A @ Aty
num = snp.sum(y.conj() * AAty)
den = snp.linalg.norm(Aty) ** 2
np.testing.assert_allclose(num / den, 1, rtol=rtol)

Aty = A.adj(y)
AAty = A @ Aty
num = snp.sum(y.conj() * AAty)
den = snp.linalg.norm(Aty) ** 2
np.testing.assert_allclose(num / den, 1, rtol=rtol)
assert linop.valid_adjoint(A, A.H, rtol, key)


class AbsMatOp(linop.LinearOperator):
"""Simple LinearOperator subclass for testing purposes.
Similar to linop.MatrixOperator, but does not use the specialized MatrixOperator methods (.T, adj, etc).
Used to verify the LinearOperator interface.
Similar to linop.MatrixOperator, but does not use the specialized
MatrixOperator methods (.T, adj, etc). Used to verify the
LinearOperator interface.
"""

def __init__(self, A, adj_fn=None):
Expand Down Expand Up @@ -369,8 +316,7 @@ def test_adjoint(self, input_shape, diagonal_dtype):
diagonal, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)
D = linop.Diagonal(diagonal=diagonal)

adjoint_AtA_test(D)
adjoint_AAt_test(D)
adjoint_test(D)

@pytest.mark.parametrize("operator", [op.add, op.sub])
@pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64])
Expand Down Expand Up @@ -544,8 +490,7 @@ def test_sum_eval(sumtestobj, axis):
def test_sum_adj(sumtestobj, axis):
x = sumtestobj.x
A = linop.Sum(input_shape=x.shape, input_dtype=x.dtype, sum_axis=axis)
adjoint_AtA_test(A)
adjoint_AAt_test(A)
adjoint_test(A)


@pytest.mark.parametrize("axis", (5, (1, 1), (0, 1, 2, 3, 4)))
Expand Down
5 changes: 2 additions & 3 deletions scico/test/linop/test_optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
radial_transverse_frequency,
)
from scico.random import randn
from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test
from scico.test.linop.test_linop import adjoint_test

prop_list = [AngularSpectrumPropagator, FresnelPropagator, FraunhoferPropagator]

Expand All @@ -29,8 +29,7 @@ def setup_method(self, method):
@pytest.mark.parametrize("prop", prop_list)
def test_prop_adjoint(self, prop, ndim):
A = prop(input_shape=(self.N,) * ndim, dx=self.dx, k0=self.k0, z=self.z)
adjoint_AtA_test(A, self.key)
adjoint_AAt_test(A, self.key)
adjoint_test(A, self.key)

@pytest.mark.parametrize("ndim", [1, 2])
def test_AS_inverse(self, ndim):
Expand Down
5 changes: 2 additions & 3 deletions scico/test/linop/test_radon_astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test
from scico.test.linop.test_linop import adjoint_test

try:
from scico.linop.radon_astra import ParallelBeamProjector
Expand Down Expand Up @@ -95,5 +95,4 @@ def test_adjoint_grad(testobj):

def test_adjoint(testobj):
A = testobj.A
adjoint_AAt_test(A, rtol=get_tol())
adjoint_AtA_test(A, rtol=get_tol())
adjoint_test(A, rtol=get_tol())
5 changes: 2 additions & 3 deletions scico/test/linop/test_radon_svmbir.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import scico
import scico.numpy as snp
from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test
from scico.test.linop.test_linop import adjoint_test
from scico.test.test_functional import prox_test

try:
Expand Down Expand Up @@ -61,8 +61,7 @@ def test_adjoint(Nx, Ny, num_angles, num_channels, is_3d):
im = make_im(Nx, Ny, is_3d)
A = make_A(im, num_angles, num_channels)

adjoint_AtA_test(A)
adjoint_AAt_test(A)
adjoint_test(A)


@pytest.mark.parametrize("Nx, Ny, num_angles, num_channels", (SMALL_INPUT,))
Expand Down
Loading

0 comments on commit 656f59d

Please sign in to comment.