-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve LinearOperator adjoint tests (#72)
* Docstring cleanup * Docstring fix * Add adjoint validation function * Add adjoint validation function * Fix for complex values * Update docstring * Replace adjoint tests * Switch to new adjoint test function in other linop module tests
- Loading branch information
Showing
12 changed files
with
109 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,17 +22,24 @@ | |
from scico.random import randn | ||
from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape | ||
|
||
__author__ = """Luke Pfister <[email protected]>""" | ||
__author__ = """\n""".join( | ||
["Luke Pfister <[email protected]>", "Brendt Wohlberg <[email protected]>"] | ||
) | ||
|
||
|
||
def power_iteration(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] = None): | ||
"""Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator` using power iteration. | ||
"""Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator`. | ||
Compute largest eigenvalue of a diagonalizable | ||
:class:`.LinearOperator` using power iteration. | ||
Args: | ||
A: :class:`.LinearOperator` used for computation. Must be diagonalizable. | ||
For arbitrary :class:`.LinearOperator`, call this function on ``A.conj().T @ A``. | ||
A: :class:`.LinearOperator` used for computation. Must be | ||
diagonalizable. For arbitrary :class:`.LinearOperator`, call | ||
this function on ``A.conj().T @ A``. | ||
maxiter: Maximum number of power iterations to use. Default: 100 | ||
key: Jax PRNG key. Defaults to None, in which case a new key is created. | ||
key: Jax PRNG key. Defaults to None, in which case a new key is | ||
created. | ||
Returns: | ||
tuple: A tuple (mu, v) containing: | ||
|
@@ -51,8 +58,62 @@ def power_iteration(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey | |
return mu, v | ||
|
||
|
||
def valid_adjoint( | ||
A: LinearOperator, | ||
AT: LinearOperator, | ||
eps: Optional[float] = 1e-7, | ||
key: Optional[PRNGKey] = None, | ||
) -> Union[bool, float]: | ||
r"""Check whether :class:`.LinearOperator` `AT` is the adjoint of `A`. | ||
The test exploits the identity | ||
.. math:: | ||
\mathbf{y}^T (A \mathbf{x}) = (\mathbf{y}^T A) \mathbf{x} = | ||
(A^T \mathbf{y})^T \mathbf{x} | ||
by computing :math:`\mathbf{u} = A \mathbf{x}` and | ||
:math:`\mathbf{v} = A^T \mathbf{y}` for random :math:`\mathbf{x}` | ||
and :math:`\mathbf{y}` and confirming that :math:`\| \mathbf{y}^T | ||
\mathbf{u} - \mathbf{v}^T \mathbf{x} \|_2 < \epsilon` since | ||
.. math:: | ||
\mathbf{y}^T \mathbf{u} = \mathbf{y}^T (A \mathbf{x}) = | ||
(A^T \mathbf{y})^T \mathbf{x} = \mathbf{v}^T \mathbf{x} | ||
when :math:`A^T` is a valid adjoint of :math:`A`. If :math:`A` is a | ||
complex operator (with a complex `input_dtype`) then the test checks | ||
whether `AT` is the Hermitian conjugate of `A`, with a test as above, | ||
but with all the :math:`\cdot^T` replaced with :math:`\cdot^H`. | ||
Args: | ||
A: Primary :class:`.LinearOperator`. | ||
AT: Adjoint :class:`.LinearOperator`. | ||
eps: Error threshold for validation of `AT` as adjoint of `A`. If | ||
None, the relative error is returned instead of a boolean value. | ||
key: Jax PRNG key. Defaults to None, in which case a new key is | ||
created. | ||
Returns: | ||
Boolean value indicating that validation passed, or relative error | ||
of test, depending on type of parameter `eps`. | ||
""" | ||
|
||
x0, key = randn(shape=A.input_shape, key=key, dtype=A.input_dtype) | ||
x1, key = randn(shape=AT.input_shape, key=key, dtype=AT.input_dtype) | ||
y0 = A(x0) | ||
y1 = AT(x1) | ||
x1y0 = snp.dot(x1.ravel().conj(), y0.ravel()) | ||
y1x0 = snp.dot(y1.ravel().conj(), x0.ravel()) | ||
err = snp.linalg.norm(x1y0 - y1x0) / max(snp.linalg.norm(x1y0), snp.linalg.norm(y1x0)) | ||
if eps is None: | ||
return err | ||
else: | ||
return err < eps | ||
|
||
|
||
class Diagonal(LinearOperator): | ||
"""Diagonal linear operator""" | ||
"""Diagonal linear operator.""" | ||
|
||
def __init__(self, diagonal: JaxArray, input_dtype: Optional[DType] = None, **kwargs): | ||
r""" | ||
|
@@ -124,7 +185,7 @@ def __rmatmul__(self, x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockAr | |
|
||
|
||
class Sum(LinearOperator): | ||
"""A linear operator for summing along an axis or set of axes""" | ||
"""A linear operator for summing along an axis or set of axes.""" | ||
|
||
def __init__( | ||
self, | ||
|
@@ -138,12 +199,15 @@ def __init__( | |
Wraps :func:`jax.numpy.sum` as a :class:`.LinearOperator`. | ||
Args: | ||
sum_axis: The axis or set of axes to sum over. If `None`, sum is taken over all axes. | ||
sum_axis: The axis or set of axes to sum over. If `None`, | ||
sum is taken over all axes. | ||
input_shape: Shape of input array. | ||
input_dtype: `dtype` for input argument. | ||
Defaults to `float32`. If this LinearOperator implements complex-valued operations, | ||
this must be `complex64` for proper adjoint and gradient calculation. | ||
jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. | ||
Defaults to `float32`. If this LinearOperator implements | ||
complex-valued operations, this must be `complex64` for | ||
proper adjoint and gradient calculation. | ||
jit: If ``True``, jit the evaluation, adjoint, and gram | ||
functions of the LinearOperator. | ||
""" | ||
|
||
input_ndim = len(input_shape) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.