From 9438cf545af8d580aae858b178dc845246223cec Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 6 Dec 2021 15:55:30 -0700 Subject: [PATCH] Docstring cleanup and improvement of style guide compliance (#119) * Various cleanup, mainly docstrings * Docstring cleanup * Typo bug fix * Use intersphinx refs * Some improvements * Remove explicit copy of jax docstrings * Fix indentation issues * Clean up docstrings * Minor edit * Style guide compliance * Style guide compliance * Cleanup * Fix docstring format problem * Cleanup and style compliance * Apply black manually * Docstring cleanup and style compliance * Style compliance * Improve formatting of returns specification * Alternative formatting of multiple returns * Formatting improvement * Minor changes * Format as code Co-authored-by: Mike McCann <57153404+Michael-T-McCann@users.noreply.github.com> --- docs/source/conf.py | 4 + scico/__init__.py | 16 +- scico/_autograd.py | 178 ++++----------- scico/_generic_operators.py | 191 +++++++++------- scico/admm.py | 262 ++++++++++++---------- scico/blockarray.py | 337 +++++++++++++++-------------- scico/data/__init__.py | 16 +- scico/diagnostics.py | 56 ++--- scico/flax.py | 29 ++- scico/functional/__init__.py | 2 +- scico/functional/_denoiser.py | 24 ++- scico/functional/_flax.py | 10 +- scico/functional/_functional.py | 76 ++++--- scico/functional/_indicator.py | 42 ++-- scico/functional/_norm.py | 104 +++++---- scico/linop/_circconv.py | 75 ++++--- scico/linop/_convolve.py | 43 ++-- scico/linop/_dft.py | 25 ++- scico/linop/_diff.py | 53 +++-- scico/linop/_matrix.py | 39 ++-- scico/linop/_stack.py | 23 +- scico/linop/optics.py | 75 +++---- scico/linop/radon_astra.py | 12 +- scico/linop/radon_svmbir.py | 11 +- scico/loss.py | 72 ++++--- scico/math.py | 34 +-- scico/metric.py | 60 +++--- scico/numpy/__init__.py | 11 +- scico/numpy/_create.py | 73 ++++--- scico/numpy/_util.py | 14 +- scico/numpy/fft.py | 8 +- scico/numpy/linalg.py | 20 +- scico/operator/biconvolve.py | 28 +-- scico/pgm.py | 85 +++++--- scico/plot.py | 372 +++++++++++++++++--------------- scico/random.py | 55 ++--- scico/scipy/special.py | 8 +- scico/solver.py | 97 +++++---- scico/typing.py | 18 +- scico/util.py | 54 ++--- 40 files changed, 1479 insertions(+), 1233 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7d46fc9ad..81306eb98 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -265,6 +265,10 @@ def patched_parse(self): napoleon_use_ivar = True napoleon_use_rtype = False +# See https://github.com/sphinx-doc/sphinx/issues/9119 +# napoleon_custom_sections = [("Returns", "params_style")] + + graphviz_output_format = "svg" inheritance_graph_attrs = dict(rankdir="LR", fontsize=9, ratio="compress", bgcolor="transparent") inheritance_node_attrs = dict( diff --git a/scico/__init__.py b/scico/__init__.py index 8961631bc..561175872 100644 --- a/scico/__init__.py +++ b/scico/__init__.py @@ -4,7 +4,9 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""SCICO -- a Python package for solving the inverse problems that arise in scientific imaging applications.""" +"""Scientific Computational Imaging COde (SCICO) is a Python package for +solving the inverse problems that arise in scientific imaging applications. +""" __version__ = "0.0.2a1" @@ -16,13 +18,17 @@ # TODO remove this check as we get closer to release? 
import jax, jaxlib -if jax.__version__ < "0.2.19": +jax_ver_req = "0.2.19" +jaxlib_ver_req = "0.1.70" +if jax.__version__ < jax_ver_req: raise Exception( - f"""SCICO {__version__} requires jax>0.2.19; got {jax.__version__}; please upgrade jax.""" + f"SCICO {__version__} requires jax>={jax_ver_req}; got {jax.__version__}; " + "please upgrade jax." ) -if jaxlib.__version__ < "0.1.70": +if jaxlib.__version__ < jaxlib_ver_req: raise Exception( - f"""SCICO {__version__} requires jaxlib>0.1.70; got {jaxlib.__version__}; please upgrade jaxlib.""" + f"SCICO {__version__} requires jaxlib>={jaxlib_ver_req}; got {jaxlib.__version__}; " + "please upgrade jaxlib." ) from jax import custom_jvp, custom_vjp, jacfwd, jvp, linearize, vjp, hessian diff --git a/scico/_autograd.py b/scico/_autograd.py index 12084b6b6..81f41f1af 100644 --- a/scico/_autograd.py +++ b/scico/_autograd.py @@ -4,7 +4,7 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Automatric differentiation tools.""" +"""Automatic differentiation tools.""" from typing import Any, Callable, Sequence, Tuple, Union @@ -13,6 +13,21 @@ import jax.numpy as jnp +def _append_jax_docs(fn, jaxfn=None): + """Append the jax function docs. + + Given wrapper function ``fn``, concatenate its docstring with the + docstring of the wrapped jax function. + """ + + name = fn.__name__ + if jaxfn is None: + jaxfn = getattr(jax, name) + doc = " " + fn.__doc__.replace("\n ", "\n ") # deal with indentation differences + jaxdoc = "\n".join(jaxfn.__doc__.split("\n")[2:]) # strip initial lines + return doc + f"\n Docstring for :func:`jax.{name}`:\n\n" + jaxdoc + + def grad( fun: Callable, argnums: Union[int, Sequence[int]] = 0, @@ -20,38 +35,10 @@ def grad( holomorphic: bool = False, allow_int: bool = False, ) -> Callable: - """Creates a function which evaluates the gradient of ``fun``. - - :func:`scico.grad` differs from :func:`jax.grad` in that the output is conjugated. - - Docstring for :func:`jax.grad`: - - Args: - fun: Function to be differentiated. Its arguments at positions specified by - ``argnums`` should be arrays, scalars, or standard Python containers. - Argument arrays in the positions specified by ``argnums`` must be of - inexact (i.e., floating-point or complex) type. It - should return a scalar (which includes arrays with shape ``()`` but not - arrays with shape ``(1,)`` etc.) - argnums: Optional, integer or sequence of integers. Specifies which - positional argument(s) to differentiate with respect to (default 0). - has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the - first element is considered the output of the mathematical function to be - differentiated and the second element is auxiliary data. Default False. - holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be - holomorphic. If True, inputs and outputs must be complex. Default False. - allow_int: Optional, bool. Whether to allow differentiating with - respect to integer valued inputs. The gradient of an integer input will - have a trivial vector-space dtype (float0). Default False. - - Returns: - A function with the same arguments as ``fun``, that evaluates the gradient - of ``fun``. If ``argnums`` is an integer then the gradient has the same - shape and type as the positional argument indicated by that integer. If - argnums is a tuple of integers, the gradient is a tuple of values with the - same shapes and types as the corresponding arguments. 
If ``has_aux`` is True - then a pair of (gradient, auxiliary_data) is returned. + """Create a function that evaluates the gradient of ``fun``. + :func:`scico.grad` differs from :func:`jax.grad` in that the output + is conjugated. """ jax_grad = jax.grad( @@ -69,6 +56,10 @@ def conjugated_grad(*args, **kwargs): return conjugated_grad_aux if has_aux else conjugated_grad +# Append docstring from original jax function +grad.__doc__ = _append_jax_docs(grad) + + def value_and_grad( fun: Callable, argnums: Union[int, Sequence[int]] = 0, @@ -76,36 +67,10 @@ def value_and_grad( holomorphic: bool = False, allow_int: bool = False, ) -> Callable[..., Tuple[Any, Any]]: - """Create a function which evaluates both ``fun`` and the gradient of ``fun``. - - :func:`scico.value_and_grad` differs from :func:`jax.value_and_grad` in that the gradient is conjugated. - - Docstring for :func:`jax.value_and_grad`: - - - Args: - fun: Function to be differentiated. Its arguments at positions specified by - ``argnums`` should be arrays, scalars, or standard Python containers. It - should return a scalar (which includes arrays with shape ``()`` but not - arrays with shape ``(1,)`` etc.) - argnums: Optional, integer or sequence of integers. Specifies which - positional argument(s) to differentiate with respect to (default 0). - has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the - first element is considered the output of the mathematical function to be - differentiated and the second element is auxiliary data. Default False. - holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be - holomorphic. If True, inputs and outputs must be complex. Default False. - allow_int: Optional, bool. Whether to allow differentiating with - respect to integer valued inputs. The gradient of an integer input will - have a trivial vector-space dtype (float0). Default False. - - Returns: - A function with the same arguments as ``fun`` that evaluates both ``fun`` - and the gradient of ``fun`` and returns them as a pair (a two-element - tuple). If ``argnums`` is an integer then the gradient has the same shape - and type as the positional argument indicated by that integer. If argnums is - a sequence of integers, the gradient is a tuple of values with the same - shapes and types as the corresponding arguments. + """Create a function that evaluates both ``fun`` and its gradient. + + :func:`scico.value_and_grad` differs from :func:`jax.value_and_grad` + in that the gradient is conjugated. """ jax_val_grad = jax.value_and_grad( fun=fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int @@ -124,50 +89,17 @@ def conjugated_value_and_grad(*args, **kwargs): return conjugated_value_and_grad_aux if has_aux else conjugated_value_and_grad +# Append docstring from original jax function +value_and_grad.__doc__ = _append_jax_docs(value_and_grad) + + def linear_adjoint(fun: Callable, *primals) -> Callable: - """Conjugate transpose a function that is promised to be linear. + """Conjugate transpose a function that is guaranteed to be linear. :func:`scico.linear_adjoint` differs from :func:`jax.linear_transpose` - for complex inputs in that the conjugate transpose (adjoint) of `fun` is returned. - :func:`scico.linear_adjoint` is identical to :func:`jax.linear_transpose` - for real-valued primals. - - Docstring for :func:`jax.linear_transpose`: - - For linear functions, this transformation is equivalent to ``vjp``, but - avoids the overhead of computing the forward pass. 
- - The outputs of the transposed function will always have the exact same dtypes - as ``primals``, even if some values are truncated (e.g., from complex to - float, or from float64 to float32). To avoid truncation, use dtypes in - ``primals`` that match the full range of desired outputs from the transposed - function. Integer dtypes are not supported. - - Args: - fun: the linear function to be transposed. - *primals: a positional argument tuple of arrays, scalars, or (nested) - standard Python containers (tuples, lists, dicts, namedtuples, i.e., - pytrees) of those types used for evaluating the shape/dtype of - ``fun(*primals)``. These arguments may be real scalars/ndarrays, but that - is not required: only the ``shape`` and `dtype` attributes are accessed. - See below for an example. (Note that the duck-typed objects cannot be - namedtuples because those are treated as standard Python containers.) - - Returns: - A callable that calculates the transpose of ``fun``. Valid input into this - function must have the same shape/dtypes/structure as the result of - ``fun(*primals)``. Output will be a tuple, with the same - shape/dtypes/structure as ``primals``. - - >>> import jax - >>> import types - >>> import numpy as np - >>> - >>> f = lambda x, y: 0.5 * x - 0.5 * y - >>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32) - >>> f_transpose = jax.linear_transpose(f, scalar, scalar) - >>> f_transpose(1.0) - (DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32)) + for complex inputs in that the conjugate transpose (adjoint) of `fun` + is returned. :func:`scico.linear_adjoint` is identical to + :func:`jax.linear_transpose` for real-valued primals. """ def conj_fun(*primals): @@ -189,6 +121,10 @@ def conj_fun(*primals): return jax.linear_transpose(_fun, *_primals) +# Append docstring from original jax function +linear_adjoint.__doc__ = _append_jax_docs(linear_adjoint, jaxfn=jax.linear_transpose) + + def jacrev( fun: Callable, argnums: Union[int, Sequence[int]] = 0, @@ -197,36 +133,8 @@ def jacrev( ) -> Callable: """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. - :func:`scico.jacrev` differs from :func:`jax.jacrev` in that the output is conjugated. - - Docstring for :func:`jax.jacrev`: - - Args: - fun: Function whose Jacobian is to be computed. - argnums: Optional, integer or sequence of integers. Specifies which - positional argument(s) to differentiate with respect to (default ``0``). - holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be - holomorphic. Default False. - allow_int: Optional, bool. Whether to allow differentiating with - respect to integer valued inputs. The gradient of an integer input will - have a trivial vector-space dtype (float0). Default False. - - Returns: - A function with the same arguments as ``fun``, that evaluates the Jacobian of - ``fun`` using reverse-mode automatic differentiation. - - >>> import jax - >>> import jax.numpy as jnp - >>> - >>> def f(x): - ... return jnp.asarray( - ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])]) - ... - >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.]))) # doctest: +SKIP - [[ 1. 0. 0. ] - [ 0. 0. 5. ] - [ 0. 16. -2. ] - [ 1.6209 0. 0.84147]] + :func:`scico.jacrev` differs from :func:`jax.jacrev` in that the + output is conjugated. 
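+
+    A minimal sketch of the conjugation behavior (illustrative, for a
+    holomorphic function; the exact output shown is hypothetical)::
+
+        f = lambda x: 1j * x
+        x = jnp.array([1.0 + 0j, 2.0 + 0j])
+        scico.jacrev(f, holomorphic=True)(x)
+        # jax.jacrev gives 1j on the diagonal of the Jacobian;
+        # scico.jacrev returns the conjugate, with -1j on the diagonal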
""" jax_jacrev = jax.jacrev(fun=fun, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int) @@ -236,3 +144,7 @@ def conjugated_jacrev(*args, **kwargs): return jax.tree_map(jax.numpy.conj, tmp) return conjugated_jacrev + + +# Append docstring from original jax function +jacrev.__doc__ = _append_jax_docs(jacrev) diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index 588865941..15bdee593 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -89,20 +89,24 @@ def __init__( Args: input_shape: Shape of input array. output_shape: Shape of output array. - Defaults to ``None``. If ``None``, `output_shape` is determined by evaluating - `self.__call__` on an input array of zeros. + Defaults to ``None``. If ``None``, `output_shape` is + determined by evaluating `self.__call__` on an input + array of zeros. eval_fn: Function used in evaluating this Operator. - Defaults to ``None``. If ``None``, then `self.__call__` must be defined in any - derived classes. + Defaults to ``None``. If ``None``, then `self.__call__` + must be defined in any derived classes. input_dtype: `dtype` for input argument. - Defaults to `float32`. If Operator implements complex-valued operations, - this must be `complex64` for proper adjoint and gradient calculation. + Defaults to `float32`. If Operator implements + complex-valued operations, this must be `complex64` for + proper adjoint and gradient calculation. output_dtype: `dtype` for output argument. - Defaults to ``None``. If ``None``, `output_shape` is determined by evaluating - `self.__call__` on an input array of zeros. - jit: If ``True``, call :meth:`Operator.jit()` on this Operator to jit the forward, - adjoint, and gram functions. Same as calling :meth:`Operator.jit` after - the Operator is created. + Defaults to ``None``. If ``None``, `output_shape` is + determined by evaluating `self.__call__` on an input + array of zeros. + jit: If ``True``, call :meth:`Operator.jit()` on this + Operator to jit the forward, adjoint, and gram functions. + Same as calling :meth:`Operator.jit` after the Operator + is created. """ #: Shape of input array or :class:`.BlockArray`. @@ -172,13 +176,14 @@ def jit(self): def __call__( self, x: Union[Operator, JaxArray, BlockArray] ) -> Union[Operator, JaxArray, BlockArray]: - r"""Evaluates this Operator at the point :math:`\mb{x}`. + r"""Evaluate this Operator at the point :math:`\mb{x}`. Args: - x: Point at which to evaluate this Operator. - If `x` is a :class:`DeviceArray` or :class:`.BlockArray`, must have - `shape == self.input_shape`. If `x` is a :class:`.Operator` or :class:`.LinearOperator`, must have - `x.output_shape == self.input_shape`. + x: Point at which to evaluate this Operator. If `x` is a + :class:`DeviceArray` or :class:`.BlockArray`, must have + `shape == self.input_shape`. If `x` is a + :class:`.Operator` or :class:`.LinearOperator`, must have + `x.output_shape == self.input_shape`. """ if isinstance(x, Operator): @@ -200,7 +205,8 @@ def __call__( return self._eval(x) else: raise ValueError( - f"""Cannot evaluate {type(self)} with input_shape={self.input_shape} on array with shape={x.shape}""" + f"Cannot evaluate {type(self)} with input_shape={self.input_shape} " + f"on array with shape={x.shape}" ) else: # What is the context under which this gets called? @@ -287,24 +293,27 @@ def jvp(self, primals, tangents): return jax.jvp(self, primals, tangents) def vjp(self, *primals): - """Computes a vector-Jacobian product. + """Compute a vector-Jacobian product. 
Args:
-            primals: Sequence of values at which the Jacobian is evaluated,
-                with length equal to the number of position arguments of `_eval`.
+            primals: Sequence of values at which the Jacobian is
+                evaluated, with length equal to the number of positional
+                arguments of `_eval`.
 
         """
         primals, self_vjp = jax.vjp(self, *primals)
         return primals, self_vjp
 
     def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator:
-        """Returns a new Operator with block argument `argnum` fixed to value `val`.
+        """Return a new Operator with fixed block argument `argnum`.
 
-        Args:
-            argnum: Index of block to freeze. Must be less than or equal to the
-                number of blocks in an input array.
-            val: Value to fix the `argnum`-th input to.
+        Return a new Operator with block argument `argnum` fixed to value
+        `val`.
 
+        Args:
+            argnum: Index of block to freeze. Must be less than the
+                number of blocks in an input array.
+            val: Value to fix the `argnum`-th input to.
         """
 
         if not is_nested(self.input_shape):
@@ -441,25 +450,28 @@ def __init__(
         Args:
             input_shape: Shape of input array.
             output_shape: Shape of output array.
-                Defaults to None. If None, ``output_shape`` is determined by evaluating
-                ``self.__call__`` on an input array of zeros.
+                Defaults to None. If None, ``output_shape`` is determined
+                by evaluating ``self.__call__`` on an input array of
+                zeros.
             eval_fn: Function used in evaluating this LinearOperator.
-                Defaults to None. If None, then ``self.__call__`` must be defined in any
-                derived classes.
-            adj_fn: Function used to evaluate the adjoint of this LinearOperator.
-                Defaults to None. If None, the adjoint
+                Defaults to None. If None, then ``self.__call__`` must
+                be defined in any derived classes.
+            adj_fn: Function used to evaluate the adjoint of this
+                LinearOperator. Defaults to None. If None, the adjoint
                 is not set, and the :meth:`._set_adjoint` will be called silently
                 at the first :meth:`.adj` call or can be called manually.
             input_dtype: `dtype` for input argument.
-                Defaults to `float32`. If ``LinearOperator`` implements complex-valued operations,
-                this must be `complex64` for proper adjoint and gradient calculation.
+                Defaults to `float32`. If ``LinearOperator`` implements
+                complex-valued operations, this must be `complex64` for
+                proper adjoint and gradient calculation.
             output_dtype: `dtype` for output argument.
-                Defaults to None. If None, ``output_shape`` is determined by evaluating
-                ``self.__call__`` on an input array of zeros.
-            jit: If ``True``, call :meth:`.jit()` on this LinearOperator to jit the forward,
-                adjoint, and gram functions. Same as calling :meth:`.jit` after
-                the LinearOperator is created.
+                Defaults to None. If None, ``output_dtype`` is determined
+                by evaluating ``self.__call__`` on an input array of
+                zeros.
+            jit: If ``True``, call :meth:`.jit()` on this LinearOperator
+                to jit the forward, adjoint, and gram functions. Same as
+                calling :meth:`.jit` after the LinearOperator is created.
         """
 
         super().__init__(
@@ -577,13 +589,14 @@ def __rmatmul__(self, other):
     def __call__(
         self, x: Union[LinearOperator, JaxArray, BlockArray]
     ) -> Union[LinearOperator, JaxArray, BlockArray]:
-        r"""Evaluates this LinearOperator at the point :math:`\mb{x}`.
+        r"""Evaluate this LinearOperator at the point :math:`\mb{x}`.
 
         Args:
-            x: Point at which to evaluate this ``LinearOperator``.
-            If ``x`` is a :class:`DeviceArray` or :class:`.BlockArray`, must have
-            ``shape == self.input_shape``. If ``x`` is a :class:`.LinearOperator`, must have
-            ``x.output_shape == self.input_shape``.
+            x: Point at which to evaluate this ``LinearOperator``. If
+                ``x`` is a :class:`DeviceArray` or :class:`.BlockArray`,
+                must have ``shape == self.input_shape``. If ``x`` is a
+                :class:`.LinearOperator`, must have
+                ``x.output_shape == self.input_shape``.
         """
         if isinstance(x, LinearOperator):
             return ComposedLinearOperator(self, x)
@@ -594,12 +607,16 @@ def __call__(
     def adj(
         self, y: Union[LinearOperator, JaxArray, BlockArray]
     ) -> Union[LinearOperator, JaxArray, BlockArray]:
-        """Computes the adjoint of this :class:`.LinearOperator` applied to input ``y``
+        """Adjoint of this :class:`.LinearOperator`.
+
+        Compute the adjoint of this :class:`.LinearOperator` applied to
+        input ``y``.
 
         Args:
-            y: Point at which to compute adjoint.
-            If `y` is :class:`DeviceArray` or :class:`.BlockArray`, must have
-            ``shape == self.output_shape``. If `y` is a :class:`.LinearOperator`, must have
+            y: Point at which to compute adjoint. If `y` is
+                :class:`DeviceArray` or :class:`.BlockArray`, must have
+                ``shape == self.output_shape``. If `y` is a
+                :class:`.LinearOperator`, must have
                 ``y.output_shape == self.output_shape``.
 
         Returns:
@@ -623,15 +640,17 @@ def adj(
 
     @property
     def T(self) -> LinearOperator:
-        """Returns a new :class:`LinearOperator` that implements the transpose of this :class:`LinearOperator`.
-
-        For a real-valued LinearOperator ``A`` (``A.input_dtype=np.float32` or ``np.float64``), the
-        LinearOperator ``A.T`` implements the adjoint: ``A.T(y) == A.adj(y)``.
-
-        For a complex-valued LinearOperator ``A`` (``A.input_dtype``=`np.complex64` or ``np.complex128``), the
-        LinearOperator ``A.T`` is not the adjoint. For the conjugate transpose, use ``.conj().T``
-        or :meth:`.H`.
-
+        """Transpose of this :class:`LinearOperator`.
+
+        Return a new :class:`LinearOperator` that implements the
+        transpose of this :class:`LinearOperator`. For a real-valued
+        LinearOperator ``A`` (``A.input_dtype = np.float32`` or
+        ``np.float64``), the LinearOperator ``A.T`` implements the
+        adjoint: ``A.T(y) == A.adj(y)``. For a complex-valued
+        LinearOperator ``A`` (``A.input_dtype = np.complex64`` or
+        ``np.complex128``), the LinearOperator ``A.T`` is not the
+        adjoint. For the conjugate transpose, use ``.conj().T`` or
+        :meth:`.H`.
        """
        if is_complex_dtype(self.input_dtype):
            return LinearOperator(
@@ -654,13 +673,16 @@ def T(self) -> LinearOperator:
 
     @property
     def H(self) -> LinearOperator:
-        """Returns a new :class:`LinearOperator` that is the Hermitian transpose of this :class:`LinearOperator`.
+        """Hermitian transpose of this :class:`LinearOperator`.
 
-        For a real-valued LinearOperator ``A`` (``A.input_dtype=np.float32`` or ``np.float64``), the
-        LinearOperator ``A.H`` is equivalent to ``A.T``.
-
-        For a complex-valued LinearOperator ``A`` (``A.input_dtype = np.complex64`` or ``np.complex128``), the
-        LinearOperator ``A.H`` implements the adjoint of ``A : A.H @ y == A.adj(y) == A.conj().T @ y)``.
+        Return a new :class:`LinearOperator` that is the Hermitian
+        transpose of this :class:`LinearOperator`. For a real-valued
+        LinearOperator ``A`` (``A.input_dtype=np.float32`` or
+        ``np.float64``), the LinearOperator ``A.H`` is equivalent to
+        ``A.T``. For a complex-valued LinearOperator ``A``
+        (``A.input_dtype = np.complex64`` or ``np.complex128``), the
+        LinearOperator ``A.H`` implements the adjoint of ``A``:
+        ``A.H @ y == A.adj(y) == A.conj().T @ y``.
 
         For the non-conjugate transpose, see :meth:`.T`.
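+
+        An illustrative sketch (assumes a complex-valued LinearOperator
+        ``A`` and a conforming array ``y``)::
+
+            np.allclose(A.H @ y, A.adj(y))         # True
+            np.allclose(A.H @ y, A.conj().T @ y)   # True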
""" @@ -674,7 +696,11 @@ def H(self) -> LinearOperator: ) def conj(self) -> LinearOperator: - """Returns a new :class:`.LinearOperator` ``Ac`` such that ``Ac(x) = conj(A)(x)``""" + """Complex conjugate of this :class:`LinearOperator`. + + Return a new :class:`.LinearOperator` ``Ac`` such that + ``Ac(x) = conj(A)(x)``. + """ # A.conj() x == (A @ x.conj()).conj() return LinearOperator( input_shape=self.input_shape, @@ -687,8 +713,11 @@ def conj(self) -> LinearOperator: @property def gram_op(self) -> LinearOperator: - """Returns a new :class:`.LinearOperator` ``G`` such that ``G(x) = A.adj(A(x)))``""" + """Gram operator of this :class:`LinearOperator`. + Return a new :class:`.LinearOperator` ``G`` such that + ``G(x) = A.adj(A(x)))``. + """ if self._gram is None: self._set_adjoint() @@ -704,13 +733,14 @@ def gram_op(self) -> LinearOperator: def gram( self, x: Union[LinearOperator, JaxArray, BlockArray] ) -> Union[LinearOperator, JaxArray, BlockArray]: - """Computes ``A.adj(A(x))`` + """Compute ``A.adj(A(x)).`` Args: - x: Point at which to evaluate the gram operator. - If ``x`` is a :class:`DeviceArray` or :class:`.BlockArray`, must have - ``shape == self.input_shape``. If ``x`` is a :class:`.LinearOperator`, must have - ``x.output_shape == self.input_shape``. + x: Point at which to evaluate the gram operator. If ``x`` is + a :class:`DeviceArray` or :class:`.BlockArray`, must have + ``shape == self.input_shape``. If ``x`` is a + :class:`.LinearOperator`, must have + ``x.output_shape == self.input_shape``. Returns: Result of ``A.adj(A(x))``. @@ -727,26 +757,29 @@ def __init__(self, A: LinearOperator, B: LinearOperator, jit: bool = False): r"""ComposedLinearOperator init method. A ComposedLinearOperator ``AB`` implements ``AB @ x == A @ B @ x``. - The LinearOperators ``A`` and ``B`` are stored as attributes of the ComposedLinearOperator. - - The LinearOperators ``A`` and ``B`` must have compatible shapes and dtypes: - ``A.input_shape == B.output_shape`` and ``A.input_dtype == B.input_dtype``. + The LinearOperators ``A`` and ``B`` are stored as attributes of + the ComposedLinearOperator. + The LinearOperators ``A`` and ``B`` must have compatible shapes + and dtypes: ``A.input_shape == B.output_shape`` and + ``A.input_dtype == B.input_dtype``. Args: A: First (left) LinearOperator. - B: Second (right) LinearOperator - jit: If ``True``, call :meth:`.jit()` on this LinearOperator to jit the forward, - adjoint, and gram functions. Same as calling :meth:`.jit` after - the LinearOperator is created. + B: Second (right) LinearOperator. + jit: If ``True``, call :meth:`.jit()` on this LinearOperator + to jit the forward, adjoint, and gram functions. Same as + calling :meth:`.jit` after the LinearOperator is created. 
""" if not isinstance(A, LinearOperator): raise TypeError( - f"The first argument to ComposedLinearOpeator must be a LinearOperator; got {type(A)}" + "The first argument to ComposedLinearOpeator must be a LinearOperator; " + f"got {type(A)}" ) if not isinstance(B, LinearOperator): raise TypeError( - f"The second argument to ComposedLinearOpeator must be a LinearOperator; got {type(B)}" + "The second argument to ComposedLinearOpeator must be a LinearOperator; " + f"got {type(B)}" ) if A.input_shape != B.output_shape: raise ValueError(f"Incompatable LinearOperator shapes {A.shape}, {B.shape}") diff --git a/scico/admm.py b/scico/admm.py index 155fdeab8..202be7d43 100644 --- a/scico/admm.py +++ b/scico/admm.py @@ -38,45 +38,50 @@ class SubproblemSolver: r"""Base class for solvers for the non-separable ADMM step. - The ADMM solver implemented by :class:`.ADMM` addresses a general problem form for - which one of the corresponding ADMM algorithm subproblems is separable into distinct - subproblems for each of the :math:`g_i`, and another that is non-separable, involving - function :math:`f` and a sum over :math:`\ell_2` norm terms involving all operators - :math:`C_i`. This class is a base class for solvers of the latter subproblem + The ADMM solver implemented by :class:`.ADMM` addresses a general + problem form for which one of the corresponding ADMM algorithm + subproblems is separable into distinct subproblems for each of the + :math:`g_i`, and another that is non-separable, involving function + :math:`f` and a sum over :math:`\ell_2` norm terms involving all + operators :math:`C_i`. This class is a base class for solvers of + the latter subproblem - .. math:: + .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i \frac{\rho_i}{2} - \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 - + \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;. Attributes: - admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. + admm (:class:`.ADMM`): ADMM solver object to which the + solver is attached. """ def internal_init(self, admm: "ADMM"): """Second stage initializer to be called by :meth:`.ADMM.__init__`. Args: - admm: Reference to :class:`.ADMM` object to which the :class:`.SubproblemSolver` - object is to be attached. + admm: Reference to :class:`.ADMM` object to which the + :class:`.SubproblemSolver` object is to be attached. """ self.admm = admm class GenericSubproblemSolver(SubproblemSolver): - """Solver for generic problem without special structure that can be exploited. + """Solver for generic problem without special structure. Attributes: - admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. - minimize_kwargs (dict): Dictionary of arguments for :func:`scico.solver.minimize`. + admm (:class:`.ADMM`): ADMM solver object to which the solver is + attached. + minimize_kwargs (dict): Dictionary of arguments for + :func:`scico.solver.minimize`. """ def __init__(self, minimize_kwargs: dict = {"options": {"maxiter": 100}}): """Initialize a :class:`GenericSubproblemSolver` object. Args: - minimize_kwargs : Dictionary of arguments for :func:`scico.solver.minimize`. + minimize_kwargs: Dictionary of arguments for + :func:`scico.solver.minimize`. """ self.minimize_kwargs = minimize_kwargs @@ -84,7 +89,7 @@ def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: """Solve the ADMM step. Args: - x0: Starting point. + x0: Initial value. Returns: Computed solution. 
@@ -109,54 +114,67 @@ def obj(x):
 
 
 class LinearSubproblemSolver(SubproblemSolver):
-    r"""Solver for the case where :code:`f` is a quadratic function of :math:`\mb{x}`.
+    r"""Solver for quadratic functionals.
 
-    Specialization of :class:`.SubproblemSolver` for the case where :code:`f` is an
-    :math:`\ell_2` or weighted :math:`\ell_2` norm, and :code:`f.A` is a linear
-    operator, so that the subproblem involves solving a linear equation. This requires
-    that ``f.functional`` be an instance of either :class:`.SquaredL2Loss` or
-    :class:`.WeightedSquaredL2Loss` and for the forward operator :code:`f.A` to be an
-    instance of :class:`.LinearOperator`.
+    Solver for the case in which :code:`f` is a quadratic function of
+    :math:`\mb{x}`. It is a specialization of :class:`.SubproblemSolver`
+    for the case where :code:`f` is an :math:`\ell_2` or weighted
+    :math:`\ell_2` norm, and :code:`f.A` is a linear operator, so that
+    the subproblem involves solving a linear equation. This requires that
+    ``f.functional`` be an instance of either :class:`.SquaredL2Loss`
+    or :class:`.WeightedSquaredL2Loss` and for the forward operator
+    :code:`f.A` to be an instance of :class:`.LinearOperator`.
 
-    In the case :class:`.WeightedSquaredL2Loss`, the :math:`\mb{x}`-update step is
+    In the case of :class:`.WeightedSquaredL2Loss`, the
+    :math:`\mb{x}`-update step is
 
-    .. math::
-        \mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2} \norm{\mb{y} - A x}_W^2 +
-        \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;,
+    .. math::
+
+        \mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2}
+        \norm{\mb{y} - A x}_W^2 + \sum_i \frac{\rho_i}{2}
+        \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;,
 
     where :math:`W` is the weighting :class:`.Diagonal` from the
-    :class:`.WeightedSquaredL2Loss` instance. This update step reduces to the
-    solution of the linear system
+    :class:`.WeightedSquaredL2Loss` instance. This update step
+    reduces to the solution of the linear system
 
-    .. math::
-        \left(A^H W A + \sum_{i=1}^N \rho_i C_i^H C_i \right) \mb{x}^{(k+1)} = \;
-        A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i) \;.
+    .. math::
+
+        \left(A^H W A + \sum_{i=1}^N \rho_i C_i^H C_i \right)
+        \mb{x}^{(k+1)} = \;
+        A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i -
+        \mb{u}^{(k)}_i) \;.
 
-    In the case :class:`.SquaredL2Loss` :math:`W` is replaced with the :class:`Identity` operator.
+    In the case of :class:`.SquaredL2Loss`, :math:`W` is set to
+    the :class:`Identity` operator.
 
     Attributes:
-        admm (:class:`.ADMM`): ADMM solver object to which the solver is attached.
+        admm (:class:`.ADMM`): ADMM solver object to which the solver is
+            attached.
         cg_kwargs (dict): Dictionary of arguments for CG solver.
         cg (func): CG solver function (:func:`scico.solver.cg` or
-            :func:`jax.scipy.sparse.linalg.cg`)
-        lhs (type): Function implementing the linear operator needed for the :math:`\mb{x}`
-            update step
+            :func:`jax.scipy.sparse.linalg.cg`).
+        lhs (type): Function implementing the linear operator needed
+            for the :math:`\mb{x}` update step.
     """
 
     def __init__(self, cg_kwargs: dict = {"maxiter": 100}, cg_function: str = "scico"):
         """Initialize a :class:`LinearSubproblemSolver` object.
 
         Args:
-            cg_kwargs : Dictionary of arguments for CG solver. See :func:`scico.solver.cg` or
-                :func:`jax.scipy.sparse.linalg.cg`, documentation, including how to specify
-                a preconditioner.
-            cg_function: String indicating which CG implementation to use. One of "jax" or
-                "scico"; default "scico".
+            cg_kwargs: Dictionary of arguments for CG solver. See
+                :func:`scico.solver.cg` or
+                :func:`jax.scipy.sparse.linalg.cg` documentation,
+                including how to specify a preconditioner.
+            cg_function: String indicating which CG implementation to
+                use. One of "jax" or "scico"; default "scico".
If "scico", uses :func:`scico.solver.cg`. If - "jax", uses :func:`jax.scipy.sparse.linalg.cg`. The "jax" option is slower - on small-scale problems or problems involving external functions, but - can be differentiated through. The "scico" option is faster on small-scale - problems, but slower on large-scale problems where the forward operator is - written entirely in jax. + cg_kwargs: Dictionary of arguments for CG solver. See + :func:`scico.solver.cg` or + :func:`jax.scipy.sparse.linalg.cg`, documentation, + including how to specify a preconditioner. + cg_function: String indicating which CG implementation to + use. One of "jax" or "scico"; default "scico". If + "scico", uses :func:`scico.solver.cg`. If "jax", uses + :func:`jax.scipy.sparse.linalg.cg`. The "jax" option is + slower on small-scale problems or problems involving + external functions, but can be differentiated through. + The "scico" option is faster on small-scale problems, but + slower on large-scale problems where the forward + operator is written entirely in jax. """ self.cg_kwargs = cg_kwargs if cg_function == "scico": @@ -202,7 +220,8 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]: .. math:: - A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i) \;. + A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i - + \mb{u}^{(k)}_i) \;. Returns: Computed solution. @@ -225,7 +244,7 @@ def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: """Solve the ADMM step. Args: - x0: Starting point. + x0: Initial value. Returns: Computed solution. @@ -239,15 +258,19 @@ def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: class CircularConvolveSolver(LinearSubproblemSolver): r"""Solver for linear operators diagonalized in the DFT domain. - Specialization of :class:`.LinearSubproblemSolver` for the case where :code:`f` is an - instance of :class:`.SquaredL2Loss`, the forward operator :code:`f.A` is either - an instance of :class:`.Identity` or :class:`.CircularConvolve`, and the :code:`C_i` are - all instances of :class:`.Identity` or :class:`.CircularConvolve`. None of the instances of + Specialization of :class:`.LinearSubproblemSolver` for the case + where :code:`f` is an instance of :class:`.SquaredL2Loss`, the + forward operator :code:`f.A` is either an instance of + :class:`.Identity` or :class:`.CircularConvolve`, and the + :code:`C_i` are all instances of :class:`.Identity` or + :class:`.CircularConvolve`. None of the instances of :class:`.CircularConvolve` may sum over any of their axes. Attributes: - admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. - lhs_f (array): Left hand side, in the DFT domain, of the linear equation to be solved. + admm (:class:`.ADMM`): ADMM solver object to which the solver is + attached. + lhs_f (array): Left hand side, in the DFT domain, of the linear + equation to be solved. """ def __init__(self): @@ -284,7 +307,7 @@ def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: """Solve the ADMM step. Args: - x0: Starting point. + x0: Initial value. Returns: Computed solution. @@ -310,49 +333,58 @@ class ADMM: .. math:: \argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;, - where :math:`f` and the :math:`g_i` are instances of :class:`.Functional`, - and the :math:`C_i` are :class:`.LinearOperator`. + where :math:`f` and the :math:`g_i` are instances of + :class:`.Functional`, and the :math:`C_i` are + :class:`.LinearOperator`. 
The optimization problem is solved by introducing the splitting :math:`\mb{z}_i = C_i \mb{x}` and solving .. math:: - \argmin_{\mb{x}, \mb{z}_i} \; f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) \; - \text{such that}\; C_i \mb{x} = \mb{z}_i \;, + \argmin_{\mb{x}, \mb{z}_i} \; f(\mb{x}) + \sum_{i=1}^N + g_i(\mb{z}_i) \; \text{such that}\; C_i \mb{x} = \mb{z}_i \;, - via an ADMM algorithm :cite:`glowinski-1975-approximation` :cite:`gabay-1976-dual` - :cite:`boyd-2010-distributed`. consisting of the iterations (see :meth:`step`) + via an ADMM algorithm :cite:`glowinski-1975-approximation` + :cite:`gabay-1976-dual` :cite:`boyd-2010-distributed` consisting of + the iterations (see :meth:`step`) .. math:: \begin{aligned} - \mb{x}^{(k+1)} &= \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i \frac{\rho_i}{2} - \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \\ - \mb{z}_i^{(k+1)} &= \argmin_{\mb{z}_i} \; g_i(\mb{z}_i) + \frac{\rho_i}{2} - \norm{\mb{z}_i - \mb{u}^{(k)}_i - C_i \mb{x}^{(k+1)}}_2^2 \\ - \mb{u}_i^{(k+1)} &= \mb{u}_i^{(k)} + C_i \mb{x}^{(k+1)} - \mb{z}^{(k+1)}_i \; . + \mb{x}^{(k+1)} &= \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i + \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i + \mb{x}}_2^2 \\ + \mb{z}_i^{(k+1)} &= \argmin_{\mb{z}_i} \; g_i(\mb{z}_i) + + \frac{\rho_i}{2} + \norm{\mb{z}_i - \mb{u}^{(k)}_i - C_i \mb{x}^{(k+1)}}_2^2 \\ + \mb{u}_i^{(k+1)} &= \mb{u}_i^{(k)} + C_i \mb{x}^{(k+1)} - + \mb{z}^{(k+1)}_i \; . \end{aligned} Attributes: - f (:class:`.Functional`): Functional :math:`f` (usually a :class:`.Loss`) + f (:class:`.Functional`): Functional :math:`f` (usually a + :class:`.Loss`) g_list (list of :class:`.Functional`): List of :math:`g_i` - functionals. Must be same length as :code:`C_list` and :code:`rho_list`. - C_list (list of :class:`.LinearOperator`): List of :math:`C_i` operators. + functionals. Must be same length as :code:`C_list` and + :code:`rho_list`. + C_list (list of :class:`.LinearOperator`): List of :math:`C_i` + operators. itnum (int): Iteration counter. maxiter (int): Number of ADMM outer-loop iterations. timer (:class:`.Timer`): Iteration timer. - rho_list (list of scalars): List of :math:`\rho_i` penalty parameters. - Must be same length as :code:`C_list` and :code:`g_list`. + rho_list (list of scalars): List of :math:`\rho_i` penalty + parameters. Must be same length as :code:`C_list` and + :code:`g_list`. alpha (float): Relaxation parameter. u_list (list of array-like): List of scaled Lagrange multipliers :math:`\mb{u}_i` at current iteration. - x (array-like): Solution + x (array-like): Solution. subproblem_solver (:class:`.SubproblemSolver`): Solver for :math:`\mb{x}`-update step. - z_list (list of array-like): List of auxiliary variables :math:`\mb{z}_i` - at current iteration. - z_list_old (list of array-like): List of auxiliary variables :math:`\mb{z}_i` - at previous iteration. + z_list (list of array-like): List of auxiliary variables + :math:`\mb{z}_i` at current iteration. + z_list_old (list of array-like): List of auxiliary variables + :math:`\mb{z}_i` at previous iteration. """ def __init__( @@ -371,17 +403,17 @@ def __init__( r"""Initialize an :class:`ADMM` object. Args: - f : Functional :math:`f` (usually a loss function) - g_list : List of :math:`g_i` functionals. Must be same length - as :code:`C_list` and :code:`rho_list` - C_list : List of :math:`C_i` operators - rho_list : List of :math:`\rho_i` penalty parameters. + f: Functional :math:`f` (usually a loss function). + g_list: List of :math:`g_i` functionals. 
Must be same length + as :code:`C_list` and :code:`rho_list`. + C_list: List of :math:`C_i` operators. + rho_list: List of :math:`\rho_i` penalty parameters. Must be same length as :code:`C_list` and :code:`g_list`. alpha: Relaxation parameter. No relaxation for default 1.0. - x0 : Starting point for :math:`\mb{x}`. If None, defaults to + x0: Initial value for :math:`\mb{x}`. If None, defaults to an array of zeros. - maxiter : Number of ADMM outer-loop iterations. Default: 100. - subproblem_solver : Solver for :math:`\mb{x}`-update step. + maxiter: Number of ADMM outer-loop iterations. Default: 100. + subproblem_solver: Solver for :math:`\mb{x}`-update step. Defaults to ``None``, which implies use of an instance of :class:`GenericSubproblemSolver`. verbose: Flag indicating whether iteration statistics should @@ -471,19 +503,21 @@ def objective( ) -> float: r"""Evaluate the objective function. - .. math:: - f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) + Evaluate the objective function + .. math:: + f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) \;. Args: - x: Point at which to evaluate objective function. If `None`, the objective is - evaluated at the current iterate :code:`self.x` - z_list: Point at which to evaluate objective function. If `None`, the objective is - evaluated at the current iterate :code:`self.z_list` - + x: Point at which to evaluate objective function. If `None`, + the objective is evaluated at the current iterate + :code:`self.x`. + z_list: Point at which to evaluate objective function. If + `None`, the objective is evaluated at the current iterate + :code:`self.z_list`. Returns: - scalar: Current value of the objective function + scalar: Current value of the objective function. """ if (x is None) != (z_list is None): raise ValueError("Both or neither of x and z_list must be supplied") @@ -500,15 +534,19 @@ def objective( def norm_primal_residual(self, x: Optional[Union[JaxArray, BlockArray]] = None) -> float: r"""Compute the :math:`\ell_2` norm of the primal residual. + Compute the :math:`\ell_2` norm of the primal residual + .. math:: - \left(\sum_{i=1}^N \norm{C_i \mb{x} - \mb{z}_i}_2^2\right)^{1/2} + \left(\sum_{i=1}^N \norm{C_i \mb{x} - + \mb{z}_i}_2^2\right)^{1/2} \;. Args: x: Point at which to evaluate primal residual. - If `None`, the primal residual is evaluated at the current iterate :code:`self.x` + If `None`, the primal residual is evaluated at the + current iterate :code:`self.x`. Returns: - Current value of primal residual + Current value of primal residual. """ if x is None: x = self.x @@ -521,11 +559,14 @@ def norm_primal_residual(self, x: Optional[Union[JaxArray, BlockArray]] = None) def norm_dual_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the dual residual. + Compute the :math:`\ell_2` norm of the dual residual + .. math:: - \left(\sum_{i=1}^N \norm{\mb{z}^{(k)}_i - \mb{z}^{(k-1)}_i}_2^2\right)^{1/2} + \left(\sum_{i=1}^N \norm{\mb{z}^{(k)}_i - + \mb{z}^{(k-1)}_i}_2^2\right)^{1/2} \;. Returns: - Current value of dual residual + Current value of dual residual. """ out = 0.0 @@ -536,15 +577,16 @@ def norm_dual_residual(self) -> float: def z_init(self, x0: Union[JaxArray, BlockArray]): r"""Initialize auxiliary variables :math:`\mb{z}_i`. - Initializes to + Initialized to .. math:: - \mb{z}_i = C_i \mb{x}^{(0)} + \mb{z}_i = C_i \mb{x}^{(0)} \;. - :code:`z_list` and :code:`z_list_old` are initialized to the same value. + :code:`z_list` and :code:`z_list_old` are initialized to the same + value. 
Args:
-            x0: Starting point for :math:`\mb{x}`
+            x0: Initial value of :math:`\mb{x}`.
         """
         z_list = [Ci(x0) for Ci in self.C_list]
         z_list_old = z_list.copy()
@@ -553,14 +595,13 @@ def z_init(self, x0: Union[JaxArray, BlockArray]):
     def u_init(self, x0: Union[JaxArray, BlockArray]):
         r"""Initialize scaled Lagrange multipliers :math:`\mb{u}_i`.
 
-        Initializes to
+        Initialized to
 
         .. math::
-            \mb{u}_i = C_i \mb{x}^{(0)}
-
+            \mb{u}_i = \mb{0} \;.
 
         Args:
-            x0: Starting point for :math:`\mb{x}`
+            x0: Initial value of :math:`\mb{x}`.
         """
         u_list = [snp.zeros(Ci.output_shape, dtype=Ci.output_dtype) for Ci in self.C_list]
         return u_list
@@ -576,7 +617,9 @@ def step(self):
             \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i -
             C_i \mb{x}}_2^2 \;.
 
-        The auxiliary variables are updated according to
+        Update auxiliary variables :math:`\mb{z}_i` and scaled Lagrange
+        multipliers :math:`\mb{u}_i`. The auxiliary variables are updated
+        according to
 
         .. math::
             \begin{aligned}
@@ -618,8 +661,9 @@ def solve(
         Run the ADMM algorithm for a total of ``self.maxiter`` iterations.
 
         Args:
-            callback: An optional callback function, taking an a single argument of type
-                :class:`ADMM`, that is called at the end of every iteration.
+            callback: An optional callback function, taking a single
+                argument of type :class:`ADMM`, that is called at the end
+                of every iteration.
 
         Returns:
             Computed solution.
diff --git a/scico/blockarray.py b/scico/blockarray.py
index 041adb077..dcc18fe06 100644
--- a/scico/blockarray.py
+++ b/scico/blockarray.py
@@ -15,17 +15,19 @@
 >>> import jax.numpy
 
 The class :class:`.BlockArray` is a `jagged array
-`_ that aims to mimic the :class:`numpy.ndarray`
-interface where appropriate.
+`_ that aims to mimic the
+:class:`numpy.ndarray` interface where appropriate.
 
-A :class:`.BlockArray` object consists of a tuple of `DeviceArray` objects that share their memory
-buffers with non-overlapping, contiguous regions of a common one-dimensional `DeviceArray`.
-A :class:`.BlockArray` contains the following size attributes:
+A :class:`.BlockArray` object consists of a tuple of `DeviceArray`
+objects that share their memory buffers with non-overlapping, contiguous
+regions of a common one-dimensional `DeviceArray`. A :class:`.BlockArray`
+contains the following size attributes:
 
 * `shape`: A tuple of tuples containing component dimensions.
-* `size`: The sum of the size of each component block; this is the length of the underlying
-  one-dimensional `DeviceArray`.
-* `num_blocks`: The number of components (blocks) that comprise the :class:`.BlockArray`.
+* `size`: The sum of the size of each component block; this is the length
+  of the underlying one-dimensional `DeviceArray`.
+* `num_blocks`: The number of components (blocks) that comprise the
+  :class:`.BlockArray`.
 
 
 Motivating Example
 ------------------
 
 Consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`.
 
-We compute the discrete differences of :math:`\mb{x}` in the horizontal and vertical directions,
-generating two new arrays: :math:`\mb{x}_h \in \mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in
-\mathbb{R}^{(n-1) \times m}`.
+We compute the discrete differences of :math:`\mb{x}` in the horizontal
+and vertical directions, generating two new arrays: :math:`\mb{x}_h \in
+\mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1)
+\times m}`.
 
-As these arrays are of different sizes, we cannot combine them into a single `ndarray`.
+As these arrays are of different shapes, we cannot combine them into a
+single `ndarray`.
Instead,
-we might vectorize each array and concatenate the resulting vectors, leading to :math:`\mb{\bar{x}}
-\in \mathbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional `ndarray`.
-Unfortunately, this makes it hard to access the individual components :math:`\mb{x}_h` and
-:math:`\mb{x}_v`.
+Instead, we might vectorize each array and concatenate
+the resulting vectors, leading to :math:`\mb{\bar{x}} \in
+\mathbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional
+`ndarray`. Unfortunately, this makes it hard to access the individual
+components :math:`\mb{x}_h` and :math:`\mb{x}_v`.
 
-Instead, we can form a :class:`.BlockArray`: :math:`\mb{x}_B = [\mb{x}_h, \mb{x}_v]`
+Instead, we can form a :class:`.BlockArray`: :math:`\mb{x}_B =
+[\mb{x}_h, \mb{x}_v]`
 
 
 ::
 
@@ -70,7 +75,6 @@
 Constructing a BlockArray
 -------------------------
 
-
 Construct from a tuple of arrays (either `ndarray` or `DeviceArray`)
 ####################################################################
 
 .. doctest::
@@ -87,13 +91,16 @@
    >>> X.num_blocks
    2
 
-| While :func:`.BlockArray.array` will accept either `ndarray` or `DeviceArray` as input, the
-  resulting :class:`.BlockArray` will be backed by a `DeviceArray` memory buffer.
+While :func:`.BlockArray.array` will accept either `ndarray` or
+`DeviceArray` as input, the resulting :class:`.BlockArray` will be backed
+by a `DeviceArray` memory buffer.
+
+**Note**: constructing a :class:`.BlockArray` always involves a copy to
+a new `DeviceArray` memory buffer.
 
-| **Note**: constructing a :class:`.BlockArray` always involves a copy to a new `DeviceArray` memory buffer.
+**Note**: by default, the resulting :class:`.BlockArray` is cast to
+single precision and will have dtype `float32` or `complex64`.
 
-| **Note**: by default, the resulting :class:`.BlockArray` cast to single precision and will have
-  dtype `float32` or `complex64`.
 
 Construct from a single vector and tuple of shapes
 ##################################################
@@ -117,11 +124,12 @@
 
 `_
 
-The index must be of the form [ibk] or [ibk,idx],
-where `ibk` is the index of the block to be updated, and `idx` is a
-general index of the elements to be updated in that block. In particular, `ibk`
-cannot be a `slice`. The general index `idx` can be omitted, in which case
-an entire block is updated.
+The index must be of the form [ibk] or [ibk,idx], where `ibk` is the
+index of the block to be updated, and `idx` is a general index of the
+elements to be updated in that block. In particular, `ibk` cannot be a
+`slice`. The general index `idx` can be omitted, in which case an entire
+block is updated.
+
 
 ============================== ==============================================
 Alternate syntax               Equivalent in-place expression
@@ -160,7 +168,8 @@
 Operations with BlockArrays with same number of blocks
 ******************************************************
 
-Let :math:`\mb{y}` be a BlockArray with the same number of blocks as :math:`\mb{x}`.
+Let :math:`\mb{y}` be a BlockArray with the same number of blocks as
+:math:`\mb{x}`.
 
   .. math::
      \mb{x} + \mb{y}
     =
     \begin{bmatrix}
       \mb{x}[0] + \mb{y}[0] \\
       \mb{x}[1] + \mb{y}[1] \\
     \end{bmatrix}
 
-This operation depends on pair of blocks from :math:`\mb{x}` and :math:`\mb{y}`
-being broadcastable against each other.
+This operation depends on each pair of blocks from :math:`\mb{x}` and
+:math:`\mb{y}` being broadcastable against each other.
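+
+For example (an illustrative sketch, assuming the ``numpy`` and
+``scico.numpy`` imports shown at the top of this docstring):
+
+.. doctest::
+
+   >>> x = BlockArray.array((np.zeros((2, 2)), np.zeros((3,))))
+   >>> y = BlockArray.array((np.ones((2, 2)), np.ones((3,))))
+   >>> (x + y).shape
+   ((2, 2), (3,))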
@@ -202,7 +211,8 @@ Operations with a 1D `ndarray` of size equal to `num_blocks` ************************************************************ -The *i*\th scalar is added to the *i*\th block of the :class:`.BlockArray`: +The *i*\th scalar is added to the *i*\th block of the +:class:`.BlockArray`: .. math:: \mb{x} @@ -228,8 +238,9 @@ Operations with an ndarray of `size` equal to :class:`.BlockArray` size *********************************************************************** -We first cast the `ndarray` to a BlockArray with same shape as :math:`\mb{x}`, then apply the -operation on the resulting BlockArrays. With ``y.size = x.size``, we have: +We first cast the `ndarray` to a BlockArray with same shape as +:math:`\mb{x}`, then apply the operation on the resulting BlockArrays. +With ``y.size = x.size``, we have: .. math:: \mb{x} @@ -241,8 +252,9 @@ \mb{x}[1] + \mb{y}[1]\\ \end{bmatrix} -Equivalently, the BlockArray is first flattened, then added to the flattened `ndarray`, and the -result is reformed into a BlockArray with the same shape as :math:`\mb{x}` +Equivalently, the BlockArray is first flattened, then added to the +flattened `ndarray`, and the result is reformed into a BlockArray with +the same shape as :math:`\mb{x}` @@ -255,8 +267,8 @@ The matmul is computed between each block of the two BlockArrays. -The BlockArrays must have the same number of blocks, and each pair of blocks -must be broadcastable. +The BlockArrays must have the same number of blocks, and each pair of +blocks must be broadcastable. .. math:: \mb{x} @ \mb{y} @@ -280,7 +292,8 @@ .. todo:: Improve this -The :class:`.Operator` and :class:`.LinearOperator` classes are designed to work on :class:`.BlockArray`. The shapes must conform: +The :class:`.Operator` and :class:`.LinearOperator` classes are designed +to work on :class:`.BlockArray`. The shapes must conform: :: @@ -292,26 +305,28 @@ NumPy ufuncs ############ -`NumPy universal functions (ufuncs) `_ are -functions that operate on an `ndarray` on an element-by-element fashion and support array -broadcasting. Examples of ufuncs are ``abs``, ``sign``, ``conj``, and ``exp``. +`NumPy universal functions (ufuncs) `_ +are functions that operate on an `ndarray` on an element-by-element +fashion and support array broadcasting. Examples of ufuncs are ``abs``, +``sign``, ``conj``, and ``exp``. -The JAX library implements most NumPy ufuncs in the :mod:`jax.numpy` module. -However, as JAX does not support subclassing of `DeviceArray`, the JAX ufuncs -cannot be used on :class:`.BlockArray`. As a workaround, we have wrapped several -JAX ufuncs for use on :class:`.BlockArray`; these are located in the -:mod:`scico.numpy` module. +The JAX library implements most NumPy ufuncs in the :mod:`jax.numpy` +module. However, as JAX does not support subclassing of `DeviceArray`, +the JAX ufuncs cannot be used on :class:`.BlockArray`. As a workaround, +we have wrapped several JAX ufuncs for use on :class:`.BlockArray`; these +are defined in the :mod:`scico.numpy` module. Reductions ########## -Reductions are functions that take an array-like as an input and return an array of lower -dimension. Examples include ``mean``, ``sum``, ``norm``. BlockArray reductions are located in the -:mod:`scico.numpy` module +Reductions are functions that take an array-like as an input and return +an array of lower dimension. Examples include ``mean``, ``sum``, ``norm``. 
+
+BlockArray reductions are located in the :mod:`scico.numpy` module.
 
-:class:`.BlockArray` tries to mirror `ndarray` reduction semantics where possible, but
-cannot provide a one-to-one match as the block components may be of different size.
+:class:`.BlockArray` tries to mirror `ndarray` reduction semantics where
+possible, but cannot provide a one-to-one match as the block components
+may be of different size.
 
 Consider the example BlockArray
 
@@ -342,15 +357,17 @@
 
-  If no axis is specified, the reduction is applied to the flattened array:
+  If no axis is specified, the reduction is applied to the flattened
+  array:
 
   .. doctest::
 
     >>> snp.sum(x, axis=None).item()
     8.0
 
-  Reducing along the 0-th axis crushes the `BlockArray` down into a single `DeviceArray`
-  and requires all blocks to have the same shape otherwise, an error is raised.
+  Reducing along the 0-th axis crushes the `BlockArray` down into a
+  single `DeviceArray` and requires all blocks to have the same shape;
+  otherwise, an error is raised.
 
   .. doctest::
 
@@ -363,7 +380,8 @@
     DeviceArray([[3., 3.],
                  [3., 3.]], dtype=float32)
 
-  Reducing along axis :math:`n` is equivalent to reducing each component along axis :math:`n-1`:
+  Reducing along axis :math:`n` is equivalent to reducing each component
+  along axis :math:`n-1`:
 
   .. math::
    \text{sum}(x, axis=1) = \begin{bmatrix}
     \sum_{i=0}^1 x[0][i, :] \\
     \sum_{i=0}^0 x[1][i] \\
   \end{bmatrix}
 
 
-  If a component does not have axis :math:`n-1`, the reduction is not applied to that component. In this example,
-  ``x[1].ndim == 1``, so no reduction is applied to block ``x[1]``.
+  If a component does not have axis :math:`n-1`, the reduction is not
+  applied to that component. In this example, ``x[1].ndim == 1``, so no
+  reduction is applied to block ``x[1]``.
 
   .. math::
    \text{sum}(x, axis=2) = \begin{bmatrix}
     \sum_{i=0}^1 x[0][:, i] \\
     x[1] \\
   \end{bmatrix}
 
@@ -437,23 +456,8 @@
 def atleast_1d(*arys):
     """Convert inputs to arrays with at least one dimension.
 
-    BlockArrays are returned unmodified.
-
-    LAX-backend implementation of :func:`atleast_1d`.
-
-    The JAX version of this function will return a copy rather than a view of the input.
-
-    *Original docstring below.*
-
-    Scalar inputs are converted to 1-dimensional arrays, whilst
-    higher-dimensional inputs are preserved.
-
-    Args:
-        arys1, arys2, ... : One or more input arrays (array_like).
-
-    Returns:
-        An array, or list of arrays, each with ``a.ndim >= 1``. Copies are made only if
-        necessary.
+    A wrapper for :func:`jax.numpy.atleast_1d` that acts as usual on
+    ndarrays and DeviceArrays, and returns BlockArrays unmodified.
     """
 
     if len(arys) == 1:
@@ -469,20 +473,31 @@ def atleast_1d(*arys):
     return out
 
 
+# Append docstring from original jax.numpy function
+atleast_1d.__doc__ = (
+    atleast_1d.__doc__.replace("\n ", "\n")  # deal with indentation differences
+    + "\nDocstring for :func:`jax.numpy.atleast_1d`:\n\n"
+    + "\n".join(jax.numpy.atleast_1d.__doc__.split("\n")[2:])
+)
+
+
 def reshape(
     a: Union[JaxArray, BlockArray], newshape: Union[Shape, BlockShape]
 ) -> Union[JaxArray, BlockArray]:
-    """Gives a new shape to an array without changing its data.
+    """Change the shape of an array without changing its data.
 
     Args:
-        a : Array to be reshaped.
+        a: Array to be reshaped.
+        newshape: The new shape should be compatible with the original
+            shape. If an integer, then the result will be a 1-D array of
+            that length. One shape dimension can be -1. In this case,
+            the value is inferred from the length of the array and
+            remaining dimensions. If a tuple of tuple of ints, a
+            :class:`.BlockArray` is returned.
 
     Returns:
-        The reshaped array. Unlike :func:`numpy.reshape`, a copy is always returned.
+        The reshaped array. Unlike :func:`numpy.reshape`, a copy is
+        always returned.
     """
 
     if is_nested(newshape):
@@ -493,13 +508,13 @@ def reshape(
 
 
 def block_sizes(shape: Union[Shape, BlockShape]) -> Axes:
-    r"""Computes the 'sizes' of (possibly nested) block shapes.
+    r"""Compute the 'sizes' of (possibly nested) block shapes.
 
     This function computes ``block_sizes(z.shape) == (_.size for _ in z)``
 
 
     Args:
-        shape: A shape tuple; possibly containing nested tuples.
+        shape: A shape tuple; possibly containing nested tuples.
 
 
     Examples:
@@ -549,7 +564,7 @@ def block_sizes(shape: Union[Shape, BlockShape]) -> Axes:
 
 
 def _flatten_blockarrays(inp, *args, **kwargs):
-    """Flattens any blockarrays present in inp, args, or kwargs"""
+    """Flatten any blockarrays present in inp, args, or kwargs."""
 
     def _flatten_if_blockarray(inp):
         if isinstance(inp, BlockArray):
@@ -564,7 +579,7 @@ def _flatten_if_blockarray(inp):
 
 
 def _block_array_ufunc_wrapper(func):
-    """Wraps a "ufunc" to allow for joint operation on `DeviceArray` and `BlockArray`"""
+    """Wrap a "ufunc" to allow for joint operation on `DeviceArray` and `BlockArray`."""
 
     @wraps(func)
     def wrapper(inp, *args, **kwargs):
@@ -592,8 +607,8 @@ def wrapper(inp, *args, **kwargs):
 
 
 def _block_array_reduction_wrapper(func):
-    """Wraps a reduction (eg sum, norm) to allow for joint operation on `DeviceArray` and
-    `BlockArray`"""
+    """Wrap a reduction (e.g., sum, norm) to allow for joint operation on
+    `DeviceArray` and `BlockArray`."""
 
     @wraps(func)
     def wrapper(inp, *args, axis=None, **kwargs):
@@ -686,8 +701,8 @@ def wrapper(self, other):
 
 
 def _block_array_binary_op_wrapper(func):
-    """Returns a decorator that performs type and shape checking for :class:`.BlockArray`
-    arithmetic
+    """Return a decorator that performs type and shape checking for
+    :class:`.BlockArray` arithmetic.
     """
 
     @wraps(func)
@@ -738,7 +753,6 @@ class _AbstractBlockArray(core.ShapedArray):
     """Abstract BlockArray class for JAX tracing.
 
     See https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
-
     """
 
     array_abstraction_level = 0  # Same as jax.core.ConcreteArray
@@ -771,10 +785,11 @@ def __init__(self, shapes, dtype):
 class BlockArray:
     """A tuple of :class:`jax.interpreters.xla.DeviceArray` objects.
 
-    A tuple of `DeviceArray` objects that all share their memory buffers with
-    non-overlapping, contiguous regions of a common one-dimensional `DeviceArray`.
-    It can be used as the common one-dimensional array via the :func:`BlockArray.ravel`
-    method, or individual component arrays can be accessed individually.
+    A tuple of `DeviceArray` objects that all share their memory buffers
+    with non-overlapping, contiguous regions of a common one-dimensional
+    `DeviceArray`. It can be used as the common one-dimensional array via
+    the :func:`BlockArray.ravel` method, or individual component arrays
+    can be accessed individually.
     """
 
     # Ensure we use BlockArray.__radd__,__rmul__, etc for binary operations of the form
@@ -786,7 +801,7 @@ def __init__(self, aval: _AbstractBlockArray, data: JaxArray):
         """BlockArray init method.
 
         Args:
-            aval: `Abstract value`_ associated to this array (shape+dtype+weak_type)
+            aval: `Abstract value`_ associated with this array (shape+dtype+weak_type).
             data: The underlying contiguous, flattened `DeviceArray`.
 
         .. _Abstract value: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
@@ -893,7 +908,7 @@ def __iter__(self) -> Iterator[int]:
 
     @property
     def blocks(self) -> Iterator[int]:
-        """Returns an iterator yielding component blocks."""
+        """Return an iterator yielding component blocks."""
         return self.__iter__()
 
     @property
@@ -903,12 +918,13 @@ def bndpos(self) -> np.ndarray:
 
     @property
     def dtype(self) -> DType:
-        """Array dtype"""
+        """Array dtype."""
        return self._data.dtype
 
     @property
     def device_buffer(self) -> Buffer:
-        """The :class:`jaxlib.xla_extension.Buffer` that backs the underlying data array"""
+        """The :class:`jaxlib.xla_extension.Buffer` that backs the
+        underlying data array."""
         return self._data.device_buffer
 
     @property
@@ -931,6 +947,7 @@ def ndim(self) -> Shape:
     @property
     def shape(self) -> BlockShape:
         """Tuple of component shapes."""
+
         return self._aval.shapes
 
     @property
@@ -940,19 +957,19 @@ def split(self) -> Tuple[JaxArray, ...]:
         return tuple(self[k] for k in range(self.num_blocks))
 
     def conj(self) -> BlockArray:
-        """Returns a :class:`.BlockArray` with complex-conjugated elements."""
+        """Return a :class:`.BlockArray` with complex-conjugated elements."""
 
         # Much faster than BlockArray.array([_.conj() for _ in self.blocks])
         return BlockArray.array_from_flattened(self.ravel().conj(), self.shape)
 
     @property
     def real(self) -> BlockArray:
-        """Returns a :class:`.BlockArray` with the real part of this array."""
+        """Return a :class:`.BlockArray` with the real part of this array."""
         return BlockArray.array_from_flattened(self.ravel().real, self.shape)
 
     @property
     def imag(self) -> BlockArray:
-        """Returns a :class:`.BlockArray` with the imaginary part of this array."""
+        """Return a :class:`.BlockArray` with the imaginary part of this array."""
         return BlockArray.array_from_flattened(self.ravel().imag, self.shape)
 
     @classmethod
@@ -962,9 +979,11 @@ def array(
         """Construct a :class:`.BlockArray` from a list or tuple of existing array-like.
 
         Args:
-            alst : Initializers for array components.
-                Can be :class:`numpy.ndarray` or :class:`jax.interpreters.xla.DeviceArray`
-            dtype : Data type of array. If none, dtype is derived from dtype of initializers
+            alst: Initializers for array components.
+                Can be :class:`numpy.ndarray` or
+                :class:`jax.interpreters.xla.DeviceArray`.
+            dtype: Data type of array. If ``None``, dtype is derived from
+                dtype of initializers.
 
         Returns:
             :class:`.BlockArray` initialized from `alst` tuple
@@ -1032,10 +1051,12 @@ def ones(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray:
 
         Args:
             shape_tuple: Tuple of shapes for component blocks
-            dtype: Desired data-type for the :class:`.BlockArray`. Default is `numpy.float32`.
+            dtype: Desired data-type for the :class:`.BlockArray`.
+                Default is `numpy.float32`.
 
         Returns:
-            :class:`.BlockArray` of ones with the given component shapes and dtype
+            :class:`.BlockArray` of ones with the given component shapes
+            and dtype.
         """
         _aval = _AbstractBlockArray(shape_tuple, dtype=dtype)
         data_ravel = jnp.ones(_aval.size, dtype=dtype)
@@ -1047,11 +1068,18 @@ def zeros(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray
 
         Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros.
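+
+        Example (an illustrative sketch; the block shapes are arbitrary)::
+
+            x = BlockArray.zeros(((32, 32), (16,)))
+            assert x[0].shape == (32, 32)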
 
         Args:
-            shape_tuple: Tuple of shapes for component blocks
-            dtype: Desired data-type for the :class:`.BlockArray`. Default is `numpy.float32`.
+            shape_tuple: Tuple of shapes for component blocks.
+            dtype: Desired data-type for the :class:`.BlockArray`.
+                Default is `numpy.float32`.
 
         Returns:
-            :class:`.BlockArray` of zeros with the given component shapes and dtype
+            :class:`.BlockArray` of zeros with the given component shapes
+            and dtype.
         """
         _aval = _AbstractBlockArray(shape_tuple, dtype=dtype)
         data_ravel = jnp.zeros(_aval.size, dtype=dtype)
@@ -1062,14 +1085,17 @@ def empty(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray
         """
         Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros.
 
-        Note: like :func:`jax.numpy.empty`, this does not return an uninitalized array.
+        Note: like :func:`jax.numpy.empty`, this does not return an
+        uninitialized array.
 
         Args:
             shape_tuple: Tuple of shapes for component blocks
-            dtype: Desired data-type for the :class:`.BlockArray`. Default is `numpy.float32`.
+            dtype: Desired data-type for the :class:`.BlockArray`.
+                Default is `numpy.float32`.
 
         Returns:
-            :class:`.BlockArray` of zeros with the given component shapes and dtype.
+            :class:`.BlockArray` of zeros with the given component shapes
+            and dtype.
         """
         _aval = _AbstractBlockArray(shape_tuple, dtype=dtype)
         data_ravel = jnp.empty(_aval.size, dtype=dtype)
@@ -1088,13 +1114,13 @@ def full(
 
         Args:
             shape_tuple: Tuple of shapes for component blocks.
-            fill_value: Fill value
+            fill_value: Fill value.
             dtype: Desired data-type for the BlockArray. The default,
-                None, means `np.array(fill_value).dtype`.
+                None, means `np.array(fill_value).dtype`.
 
         Returns:
-            :class:`.BlockArray` with the given component shapes and dtype and all entries
-            equal to `fill_value`.
+            :class:`.BlockArray` with the given component shapes and
+            dtype and all entries equal to `fill_value`.
         """
         if dtype is None:
             dtype = np.asarray(fill_value).dtype
@@ -1104,15 +1130,16 @@ def full(
         return cls(_aval, data_ravel)
 
     def copy(self) -> BlockArray:
-        """Returns a copy of this :class:`.BlockArray`.
+        """Return a copy of this :class:`.BlockArray`.
 
         This method is not implemented for BlockArray.
 
         See Also:
-            :meth:`.to_numpy`: Converts a :class:`.BlockArray` into a flattened NumPy array.
+            :meth:`.to_numpy`: Convert a :class:`.BlockArray` into a
+                flattened NumPy array.
         """
-        # jax DeviceArray copies return a NumPy ndarray.  This blockarray class must be backed
-        # by a DeviceArray, so cannot be converted to a NumPy-backed BlockArray.  The BlockArray
+        # jax DeviceArray copies return a NumPy ndarray. This blockarray class must be backed
+        # by a DeviceArray, so cannot be converted to a NumPy-backed BlockArray. The BlockArray
         # .to_numpy() method returns a flattened ndarray.
         #
         # This method may be implemented in the future if jax DeviceArray .copy() is modified to
@@ -1120,7 +1147,7 @@ def copy(self) -> BlockArray:
         raise NotImplementedError
 
     def to_numpy(self) -> np.ndarray:
-        """Returns a :class:`numpy.ndarray` containing the flattened form of this
+        """Return a :class:`numpy.ndarray` containing the flattened form of this
         :class:`.BlockArray`."""
 
         if isinstance(self._data, DeviceArray):
@@ -1130,24 +1157,24 @@ def to_numpy(self) -> np.ndarray:
         return host_arr
 
     def blockidx(self, idx: int) -> jax._src.ops.scatter._Indexable:
-        """Returns :class:`jax.ops.index` for a given component block.
+        """Return :class:`jax.ops.index` for a given component block.
 
         Args:
-            idx: Desired block index
+            idx: Desired block index.
 
         Returns:
-            :class:`jax.ops.index` pointing to desired block
+            :class:`jax.ops.index` pointing to desired block.
         """
         return jax.ops.index[self.bndpos[idx] : self.bndpos[idx + 1]]
 
     def ravel(self) -> JaxArray:
         """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`.
 
-        Note that a copy, rather than a view, of the underlying array is returned.
-        This is consistent with :func:`jax.numpy.ravel`.
+        Note that a copy, rather than a view, of the underlying array is
+        returned. This is consistent with :func:`jax.numpy.ravel`.
 
         Returns:
-            Copy of underlying flattened array
+            Copy of underlying flattened array.
 
         """
         return self._data[:]
@@ -1155,11 +1182,11 @@ def ravel(self) -> JaxArray:
     def flatten(self) -> JaxArray:
         """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`.
 
-        Note that a copy, rather than a view, of the underlying array is returned.
-        This is consistent with :func:`jax.numpy.ravel`.
+        Note that a copy, rather than a view, of the underlying array is
+        returned. This is consistent with :func:`jax.numpy.ravel`.
 
         Returns:
-            Copy of underlying flattened array
+            Copy of underlying flattened array.
 
         """
         return self._data[:]
@@ -1206,7 +1233,7 @@ def _block_array_device_put_handler(a, device):
 
 ## Handlers to use jax.device_put on BlockArray
 def _block_array_tree_flatten(block_arr):
-    """Flattens a :class:`.BlockArray` pytree.
+    """Flatten a :class:`.BlockArray` pytree.
 
     See :func:`jax.tree_util.tree_flatten`.
 
@@ -1214,7 +1241,7 @@ def _block_array_tree_flatten(block_arr):
         block_arr (:class:`.BlockArray`): :class:`.BlockArray` to flatten
 
     Returns:
-        children (tuple): :class:`.BlockArray` leaves
+        children (tuple): :class:`.BlockArray` leaves.
         aux_data (tuple): Extra metadata used to reconstruct BlockArray.
 
     """
@@ -1223,16 +1250,16 @@ def _block_array_tree_flatten(block_arr):
 
 
 def _block_array_tree_unflatten(aux_data, children):
-    """Constructs a :class:`.BlockArray` from a flattened pytree.
+    """Construct a :class:`.BlockArray` from a flattened pytree.
 
     See jax.tree_utils.tree_unflatten
 
     Args:
-        aux_data (tuple): Metadata needed to construct block array
-        children (tuple): Contains block array elements
+        aux_data (tuple): Metadata needed to construct block array.
+        children (tuple): Contains block array elements.
 
     Returns:
-        block_arr: Constructed :class:`.BlockArray`
+        block_arr: Constructed :class:`.BlockArray`.
 
     """
     return BlockArray(aux_data, children[0])
@@ -1245,16 +1272,20 @@ class _BlockArrayIndexUpdateHelper:
     """The helper class for the `at` property to call indexed update functions.
 
     The `at` property is syntactic sugar for calling the indexed update
-    functions as is done in jax. The index must be of the form [ibk] or [ibk,idx],
-    where `ibk` is the index of the block to be updated, and `idx` is a
-    general index of the elements to be updated in that block.
+    functions as is done in jax. The index must be of the form [ibk] or
+    [ibk,idx], where `ibk` is the index of the block to be updated, and
+    `idx` is a general index of the elements to be updated in that block.
 
     In particular:
     - ``x = x.at[ibk].set(y)`` is an equivalent of ``x[ibk] = y``.
    - ``x = x.at[ibk,idx].set(y)`` is an equivalent of ``x[ibk,idx] = y``.
 
-    The methods ``set, add, multiply, divide, power, maximum, minimum`` are supported.
-
+    The methods ``set, add, multiply, divide, power, min, max`` are
+    supported.
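+
+    For example (an illustrative sketch)::
+
+        x = x.at[0, 1:3].add(1.0)  # add 1.0 to elements 1:3 of block 0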
""" __slots__ = ("_block_array",) @@ -1275,13 +1302,11 @@ def __repr__(self): class _BlockArrayIndexUpdateRef: """Helper object to call indexed update functions for an (advanced) index. - This object references a source block array and a specific indexer, with the - first integer specifying the block being updated, and rest being the indexer - into the array of that block. - Methods on this object return copies of the source block array that have - been modified at the positions specified by the indexer in the given block. - - + This object references a source block array and a specific indexer, + with the first integer specifying the block being updated, and rest + being the indexer into the array of that block. Methods on this + object return copies of the source block array that have been + modified at the positions specified by the indexer in the given block. """ __slots__ = ("_block_array", "bk_index", "index") @@ -1320,7 +1345,7 @@ def _index_wrapper(self, func_str, values): def set(self, values): """Pure equivalent of ``x[idx] = y``. - Returns the value of ``x`` that would result from the NumPy-style + Return the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] = y``. See :mod:`jax.ops` for details. @@ -1330,7 +1355,7 @@ def set(self, values): def add(self, values): """Pure equivalent of ``x[idx] += y``. - Returns the value of ``x`` that would result from the NumPy-style + Return the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] += y``. See :mod:`jax.ops` for details. @@ -1340,7 +1365,7 @@ def add(self, values): def multiply(self, values): """Pure equivalent of ``x[idx] *= y``. - Returns the value of ``x`` that would result from the NumPy-style + Return the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] *= y``. See :mod:`jax.ops` for details. @@ -1350,7 +1375,7 @@ def multiply(self, values): def divide(self, values): """Pure equivalent of ``x[idx] /= y``. - Returns the value of ``x`` that would result from the NumPy-style + Return the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] /= y``. See :mod:`jax.ops` for details. @@ -1360,7 +1385,7 @@ def divide(self, values): def power(self, values): """Pure equivalent of ``x[idx] **= y``. - Returns the value of ``x`` that would result from the NumPy-style + Return the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] **= y``. See :mod:`jax.ops` for details. @@ -1370,7 +1395,7 @@ def power(self, values): def min(self, values): """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. - Returns the value of ``x`` that would result from the NumPy-style + Return the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] = minimum(x[idx], y)``. See :mod:`jax.ops` for details. @@ -1380,7 +1405,7 @@ def min(self, values): def max(self, values): """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. - Returns the value of ``x`` that would result from the NumPy-style + Return the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] = maximum(x[idx], y)``. See :mod:`jax.ops` for details. diff --git a/scico/data/__init__.py b/scico/data/__init__.py index 81581af59..ef30222ef 100644 --- a/scico/data/__init__.py +++ b/scico/data/__init__.py @@ -22,13 +22,13 @@ def _imread(filename: str, path: str = None, asfloat: bool = False) -> DeviceArr """Read an image from disk. 
 
     Args:
-        str: Base filename (i.e. without path) of image file
-        path: Path to directory containing the image file
+        filename: Base filename (i.e. without path) of image file.
+        path: Path to directory containing the image file.
         asfloat: Flag indicating whether the returned image should be
-            converted to float32 dtype with a range [0, 1]
+            converted to float32 dtype with a range [0, 1].
 
     Returns:
-        DeviceArray: image data array
+        DeviceArray: image data array.
     """
 
     if path is None:
@@ -44,10 +44,10 @@ def kodim23(asfloat: bool = False) -> DeviceArray:
 
     Args:
         asfloat: Flag indicating whether the returned image should be
-            converted to float32 dtype with a range [0, 1]
+            converted to float32 dtype with a range [0, 1].
 
     Returns:
-        DeviceArray: image data array
+        DeviceArray: image data array.
     """
 
     return _imread("kodim23.png", asfloat=asfloat)
@@ -57,10 +57,10 @@ def _flax_data_path(filename: str) -> str:
     """Get the full filename of a flax data file.
 
     Args:
-        str: Base filename (i.e. without path) of data file
+        filename: Base filename (i.e. without path) of data file.
 
     Returns:
-        str: Full filename, with path, of data file
+        str: Full filename, with path, of data file.
     """
 
     return os.path.join(os.path.dirname(__file__), "flax", filename)
diff --git a/scico/diagnostics.py b/scico/diagnostics.py
index df4b280cc..802269596 100644
--- a/scico/diagnostics.py
+++ b/scico/diagnostics.py
@@ -15,7 +15,11 @@
 
 
 class IterationStats:
-    """Display and record statistics related to convergence of iterative algorithms"""
+    """Display and record iterative algorithms statistics.
+
+    Display and record statistics related to convergence of iterative
+    algorithms.
+    """
 
     def __init__(
         self,
@@ -29,28 +33,29 @@ def __init__(
         field order is retained) specifying field names for each value to be
         inserted and a corresponding format string for when it is displayed.
         When inserted values are printed in tabular form, the
-        field lengths are taken as the maxima of the header string lengths
-        and the field lengths embedded in the format strings (if specified).
-        For best results, the field lengths should be manually specified based
-        on knowledge of the ranges of values that may be encountered. For
-        example, for a '%e' format string, the specified field length should
-        be at least the precision (e.g. '%.2e' specifies a precision of 2
-        places) plus 6 when only positive values may encountered, and plus 7
-        when negative values may be encountered.
+        field lengths are taken as the maxima of the header string
+        lengths and the field lengths embedded in the format strings (if
+        specified). For best results, the field lengths should be
+        manually specified based on knowledge of the ranges of values
+        that may be encountered. For example, for a '%e' format string,
+        the specified field length should be at least the precision (e.g.
+        '%.2e' specifies a precision of 2 places) plus 6 when only
+        positive values may be encountered, and plus 7 when negative
+        values may be encountered.
 
         Args:
-            fields: A dictionary associating field names with format strings for
-                displaying the corresponding values.
+            fields: A dictionary associating field names with format
+                strings for displaying the corresponding values.
-            ident: A dictionary associating field names.
-            with corresponding valid identifiers for use within the namedtuple used to
-            record results. Defaults to None.
-        display : Flag indicating whether results should be printed to stdout.
-            Defaults to ``False``.
-        colsep : Number of spaces seperating fields in displayed tables.
-            Defaults to 2.
+            ident: A dictionary associating field names
+                with corresponding valid identifiers for use within the
+                namedtuple used to record results. Defaults to ``None``.
+            display: Flag indicating whether results should be printed
+                to stdout. Defaults to ``False``.
+            colsep: Number of spaces separating fields in displayed
+                tables. Defaults to 2.
 
         Raises:
-            TypeError: Description
+            TypeError: If the ``fields`` parameter is not a dict.
         """
 
         # Parameter fields must be specified as an OrderedDict to ensure
@@ -130,10 +135,10 @@ def __init__(
 
     def insert(self, values: Union[List, Tuple]):
         """
-        Insert a list of values for a single iteration
+        Insert a list of values for a single iteration.
 
         Args:
-            values : Statistics for a single iteration
+            values: Statistics for a single iteration.
         """
 
         self.iterations.append(self.IterTuple(*values))
@@ -146,15 +151,16 @@ def insert(self, values: Union[List, Tuple]):
 
     def history(self, transpose: bool = False):
         """
-        Retrieve record of all inserted iterations
+        Retrieve record of all inserted iterations.
 
         Args:
-            transpose: Flag indicating whether results
-                should be returned in "transposed" form, i.e. as a namedtuple of lists
-                rather than a list of namedtuples. Default: False
+            transpose: Flag indicating whether results should be returned
+                in "transposed" form, i.e. as a namedtuple of lists
+                rather than a list of namedtuples. Default: False.
 
         Returns:
-            list of namedtuple or namedtuple of lists: Record of all inserted iterations
+            list of namedtuple or namedtuple of lists: Record of all
+            inserted iterations.
         """
 
         if transpose:
diff --git a/scico/flax.py b/scico/flax.py
index 00906f6d9..99bf17f59 100644
--- a/scico/flax.py
+++ b/scico/flax.py
@@ -26,13 +26,13 @@ class ConvBNBlock(nn.Module):
     r"""Define convolution and batch normalization Flax block.
 
     Attributes:
-        num_filters : number of filters in the convolutional layer
+        num_filters: Number of filters in the convolutional layer
             of the block.
-        conv : class of convolution to apply.
-        norm : class of batch normalization to apply.
-        act : class of activation function to apply.
-        kernel_size : size of the convolution filters. Default: (3, 3).
-        stride : convolution strides. Default: (1, 1)
+        conv: Class of convolution to apply.
+        norm: Class of batch normalization to apply.
+        act: Class of activation function to apply.
+        kernel_size: Size of the convolution filters. Default: (3, 3).
+        stride: Convolution strides. Default: (1, 1).
     """
 
     num_filters: int
@@ -71,13 +71,13 @@ class DnCNNNet(nn.Module):
     architecture for denoising described in :cite:`zhang-2017-dncnn`.
 
     Attributes:
-        depth : number of layers in the neural network.
-        channels : number of channels of input tensor.
-        num_filters : number of filters in the convolutional layers.
-        kernel_size : size of the convolution filters. Default: (3, 3).
-        strides : convolution strides. Default: (1, 1).
-        dtype : . Default: `jnp.float32`.
-        act : class of activation function to apply. Default: `nn.relu`.
+        depth: Number of layers in the neural network.
+        channels: Number of channels of input tensor.
+        num_filters: Number of filters in the convolutional layers.
+        kernel_size: Size of the convolution filters. Default: (3, 3).
+        strides: Convolution strides. Default: (1, 1).
+        dtype: Output dtype. Default: `jnp.float32`.
+        act: Class of activation function to apply. Default: `nn.relu`.
     """
 
     depth: int
@@ -148,8 +148,7 @@ def load_weights(filename: str):
     """Load trained model weights.
 
     Args:
-        filename : name of file where parameters for trained model
-            have been stored.
+        filename: Name of file containing parameters for trained model.
     """
     with open(filename, "rb") as data_file:
         bytes_input = data_file.read()
diff --git a/scico/functional/__init__.py b/scico/functional/__init__.py
index e704fd540..42742eeb1 100644
--- a/scico/functional/__init__.py
+++ b/scico/functional/__init__.py
@@ -5,7 +5,7 @@
 # user license can be found in the 'LICENSE' file distributed with the
 # package.
 
-"""Functionals and functionals classes"""
+"""Functionals and functional classes."""
 
 import sys
diff --git a/scico/functional/_denoiser.py b/scico/functional/_denoiser.py
index 0f5f03b79..1947301d4 100644
--- a/scico/functional/_denoiser.py
+++ b/scico/functional/_denoiser.py
@@ -28,10 +28,12 @@
 
 
 class BM3D(Functional):
-    r"""Functional whose prox applies the BM3D denoising algorithm :cite:`dabov-2008-image`.
+    r"""Functional whose prox applies the BM3D denoising algorithm.
 
-    The BM3D algorithm is computed using the `code `__ released
-    with :cite:`makinen-2019-exact`.
+    A pseudo-function that has the BM3D algorithm :cite:`dabov-2008-image`
+    as its proximal operator. BM3D denoising is performed using the
+    `code `__ released with
+    :cite:`makinen-2019-exact`.
     """
 
     has_eval = False
@@ -42,7 +44,8 @@ def __init__(self, is_rgb: Optional[bool] = False):
         r"""Initialize a :class:`BM3D` object.
 
         Args:
-            is_rgb : Flag indicating use of BM3D with a color transform. Default: False.
+            is_rgb: Flag indicating use of BM3D with a color transform.
+                Default: False.
         """
 
         if is_rgb is True:
@@ -58,8 +61,8 @@ def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
         r"""Apply BM3D denoiser with noise level ``lam``.
 
         Args:
-            x : input image.
-            lam : noise level.
+            x: Input image.
+            lam: Noise level.
 
         Returns:
             BM3D denoised output.
@@ -111,8 +114,9 @@ def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
 
 
 class DnCNN(FlaxMap):
-    """Flax implementation of the DnCNN denoiser :cite:`zhang-2017-dncnn`.
+    """Flax implementation of the DnCNN denoiser.
 
+    A flax implementation of the DnCNN denoiser :cite:`zhang-2017-dncnn`.
     Note that :class:`.flax.DnCNNNet` represents an untrained form of
     the generic DnCNN CNN structure, while this class represents a
     trained form with six or seventeen layers.
@@ -126,7 +130,7 @@ def __init__(self, variant: Optional[str] = "6M"):
         of each channel.
 
         Args:
-            variant : Identify the DnCNN model to be used. Options are
+            variant: Identify the DnCNN model to be used. Options are
                 '6L', '6M' (default), '6H', '17L', '17M', and '17H',
                 where the integer indicates the number of layers in the
                 network, and the postfix indicates the training noise
@@ -151,8 +155,8 @@ def prox(self, x: JaxArray, lam: float = 1) -> JaxArray:
         the output.
 
         Args:
-            x : input.
-            lam : noise estimate (ignored).
+            x: Input.
+            lam: Noise estimate (ignored).
 
         Returns:
             DnCNN denoised output.
diff --git a/scico/functional/_flax.py b/scico/functional/_flax.py
index 09ff253dc..cd8664036 100644
--- a/scico/functional/_flax.py
+++ b/scico/functional/_flax.py
@@ -21,7 +21,7 @@
 
 
 class FlaxMap(Functional):
-    r"""Functional whose prox applies a trained Flax model."""
+    r"""Functional whose prox applies a trained flax model."""
 
     has_eval = False
     has_prox = True
@@ -31,8 +31,14 @@ def __init__(self, model: Callable[..., nn.Module], variables: PyTree):
         r"""Initialize a :class:`FlaxMap` object.
 
         Args:
-            model : Flax model to apply.
-            variables : Parameters and batch stats of trained model.
+            model: Flax model to apply.
+            variables: Parameters and batch stats of trained model.
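+
+        Example (an illustrative sketch; assumes `model` and `variables`
+        come from a previously trained network)::
+
+            f = FlaxMap(model, variables)
+            y = f.prox(x)  # apply the trained model to input x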
""" self.model = model self.variables = variables @@ -45,8 +45,8 @@ def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: the output. Args: - x : input. - lam : noise estimate (ignored). + x: input. + lam: noise estimate (ignored). Returns: Output of flax model. diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index d3b7cb64b..6ccd63557 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -5,7 +5,7 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"Functional base class." "" +"""Functional base class.""" import warnings from typing import List, Optional, Union @@ -25,8 +25,9 @@ class Functional: r"""Base class for functionals. - A functional maps an :code:`array-like` to a scalar; abstractly, a functional is - a mapping from :math:`\mathbb{R}^n` or :math:`\mathbb{C}^n` to :math:`\mathbb{R}`. + A functional maps an :code:`array-like` to a scalar; abstractly, a + functional is a mapping from :math:`\mathbb{R}^n` or + :math:`\mathbb{C}^n` to :math:`\mathbb{R}`. """ #: True if this functional can be evaluated, False otherwise. @@ -67,7 +68,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: r"""Evaluate this functional at point :math:`\mb{x}`. Args: - x : Point at which to evaluate this functional. + x: Point at which to evaluate this functional. """ if not self.has_eval: @@ -85,16 +86,19 @@ def prox( `x` = :math:`\mb{x}` .. math:: - \mathrm{prox}_{\lambda f}(\mb{x}) = \argmin_{\mb{v}} \frac{1}{2} - \norm{\mb{x} - \mb{v}}_2^2 + \lambda \ \mathrm{f}(\mb{v}) \;, + \mathrm{prox}_{\lambda f}(\mb{x}) = \argmin_{\mb{v}} + \frac{1}{2} \norm{\mb{x} - \mb{v}}_2^2 + \lambda + \ \mathrm{f}(\mb{v}) \;, - where :math:`f(\mb{v})` represents this functional evaluated at :math:`\mb{v}`. + where :math:`f(\mb{v})` represents this functional evaluated at + :math:`\mb{v}`. Args: - x : Point at which to evaluate prox function. - lam : Proximal parameter :math:`\lambda`. - kwargs : Additional arguments that may be used by derived - classes. These include ``v0``, an initial guess for the minimizer. + x: Point at which to evaluate prox function. + lam: Proximal parameter :math:`\lambda`. + kwargs: Additional arguments that may be used by derived + classes. These include ``v0``, an initial guess for the + minimizer. """ if not self.has_prox: @@ -109,28 +113,29 @@ def conj_prox( Evaluate scaled proximal operator of convex conjugate (Fenchel conjugate) of this functional, with scaling - `lam` = :math:`\lambda`, and evaluated at point `x` = :math:`\mb{x}`. - Denoting this functional by :math:`f` and its convex conjugate by - :math:`f^*`, the proximal operator of :math:`f^*` is computed as - follows by exploiting the extended Moreau decomposition (see - Sec. 6.6 of :cite:`beck-2017-first`) + `lam` = :math:`\lambda`, and evaluated at point + `x` = :math:`\mb{x}`. Denoting this functional by :math:`f` and + its convex conjugate by :math:`f^*`, the proximal operator of + :math:`f^*` is computed as follows by exploiting the extended + Moreau decomposition (see Sec. 6.6 of :cite:`beck-2017-first`) .. math:: \mathrm{prox}_{\lambda f^*}(\mb{x}) = \mb{x} - \lambda \mathrm{prox}_{\lambda^{-1} f}(\mb{x / \lambda}) \;. Args: - x : Point at which to evaluate prox function. - lam : Proximal parameter :math:`\lambda`. - kwargs : additional keyword args, passed directly to ``self.prox``. + x: Point at which to evaluate prox function. + lam: Proximal parameter :math:`\lambda`. 
+            kwargs: Additional keyword args, passed directly to
+                ``self.prox``.
         """
         return x - lam * self.prox(x / lam, 1.0 / lam, **kwargs)
 
     def grad(self, x: Union[JaxArray, BlockArray]):
-        r"""Evaluates the gradient of this functional at point :math:`\mb{x}`.
+        r"""Evaluate the gradient of this functional at :math:`\mb{x}`.
 
         Args:
-            x : Point at which to evaluate gradient.
+            x: Point at which to evaluate gradient.
         """
         if not self.is_smooth:  # could be True, False, or None
             warnings.warn("This functional isn't smooth!", stacklevel=2)
@@ -189,16 +194,16 @@ def prox(
 class SeparableFunctional(Functional):
     r"""A functional that is separable in its arguments.
 
-    A separable functional :math:`f : \mathbb{C}^N \to \mathbb{R}` can be written as the sum
-    of functionals :math:`f_i : \mathbb{C}^{N_i} \to \mathbb{R}` with :math:`\sum_i N_i = N`.
-    In particular,
+    A separable functional :math:`f : \mathbb{C}^N \to \mathbb{R}` can
+    be written as the sum of functionals :math:`f_i : \mathbb{C}^{N_i}
+    \to \mathbb{R}` with :math:`\sum_i N_i = N`. In particular,
 
     .. math::
-       f(\mb{x}) = f(\mb{x}_1, \dots, \mb{x}_N) = f_1(\mb{x}_1) + \dots + f_N(\mb{x}_N)
-
-    A :class:`SeparableFunctional` accepts a :class:`.BlockArray` and is separable
-    in the block components.
+       f(\mb{x}) = f(\mb{x}_1, \dots, \mb{x}_N) = f_1(\mb{x}_1) + \dots
+       + f_N(\mb{x}_N) \;.
 
+    A :class:`SeparableFunctional` accepts a :class:`.BlockArray` and is
+    separable in the block components.
     """
 
     def __init__(self, functional_list: List[Functional]):
@@ -221,13 +226,15 @@ def __call__(self, x: BlockArray) -> float:
             return snp.sum(snp.array([fi(xi) for fi, xi in zip(self.functional_list, x)]))
         else:
             raise ValueError(
-                f"Number of blocks in x, {len(x.shape)}, and length of functional_list, {len(self.functional_list)}, do not match"
+                f"Number of blocks in x, {len(x.shape)}, and length of functional_list, "
+                f"{len(self.functional_list)}, do not match"
             )
 
     def prox(self, x: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
         r"""Evaluate proximal operator of the separable functional.
 
-        Evaluate proximal operator of the separable functional (see Theorem 6.6 of :cite:`beck-2017-first`).
+        Evaluate proximal operator of the separable functional (see
+        Theorem 6.6 of :cite:`beck-2017-first`).
 
        .. math::
           \mathrm{prox}_f(\mb{x}, \lambda)
@@ -236,17 +243,18 @@ def prox(self, x: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
              \mathrm{prox}_{f_1}(\mb{x}_1, \lambda) \\
              \vdots \\
              \mathrm{prox}_{f_N}(\mb{x}_N, \lambda) \\
-           \end{bmatrix}
+           \end{bmatrix} \;.
 
         Args:
-            x : Input array :math:`\mb{x}`
-            lam : Proximal parameter :math:`\lambda`
+            x: Input array :math:`\mb{x}`.
+            lam: Proximal parameter :math:`\lambda`.
         """
         if len(x.shape) == len(self.functional_list):
             return BlockArray.array([fi.prox(xi, lam) for fi, xi in zip(self.functional_list, x)])
         else:
             raise ValueError(
-                f"Number of blocks in x, {len(x.shape)}, and length of functional_list, {len(self.functional_list)}, do not match"
+                f"Number of blocks in x, {len(x.shape)}, and length of functional_list, "
+                f"{len(self.functional_list)}, do not match"
            )
diff --git a/scico/functional/_indicator.py b/scico/functional/_indicator.py
index 3f8cf8506..25bf474d2 100644
--- a/scico/functional/_indicator.py
+++ b/scico/functional/_indicator.py
@@ -24,13 +24,14 @@
 class NonNegativeIndicator(Functional):
     r"""Indicator function for non-negative orthant.
 
-    Returns 0 if all elements of input array-like are non-negative, and inf otherwise.
+    Returns 0 if all elements of input array-like are non-negative, and
+    inf otherwise.
 
    .. math::
        I(\mb{x}) = \begin{cases}
-        0, & \text{if } x_i \geq 0 \text{ for each } i \\
-        \infty, & \text{else}
-        \end{cases} \;
+        0  & \text{if } x_i \geq 0 \text{ for each } i \\
+        \infty  & \text{else} \;.
+        \end{cases}
 
     """
 
@@ -49,18 +50,20 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
     def prox(
         self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
     ) -> Union[JaxArray, BlockArray]:
-        r"""Evaluate proximal operator of indicator over non-negative orthant:
+        r"""Proximal operator of indicator over non-negative orthant.
+
+        Proximal operator of indicator over non-negative orthant
 
        .. math::
 
            [\mathrm{prox}(\mb{x}, \lambda)]_i =
            \begin{cases}
-            x_i, & \text{if } x_i \geq 0 \\
-            0, & \text{else}.
+            x_i  & \text{if } x_i \geq 0 \\
+            0  & \text{else} \;.
            \end{cases}
 
         Args:
-            x : Input array :math:`\mb{x}`
-            lam : Proximal parameter :math:`\lambda`
+            x: Input array :math:`\mb{x}`.
+            lam: Proximal parameter :math:`\lambda`.
         """
         return snp.maximum(x, 0)
 
@@ -68,15 +71,16 @@ def prox(
 class L2BallIndicator(Functional):
     r"""Indicator function for :math:`\ell_2` ball of given radius.
 
+    Indicator function for :math:`\ell_2` ball of given radius
+
    .. math::
        I(\mb{x}) = \begin{cases}
-        0, & \text{if } \norm{\mb{x}}_2 \leq \mathrm{radius} \\
-        \infty, & \text{else}
-        \end{cases} \;
+        0  & \text{if } \norm{\mb{x}}_2 \leq \mathrm{radius} \\
+        \infty  & \text{else} \;.
+        \end{cases}
 
     Attributes:
-        radius : Radius of :math:`\ell_2` ball
-
+        radius: Radius of :math:`\ell_2` ball.
     """
 
     has_eval = True
@@ -87,7 +91,7 @@ def __init__(self, radius: float = 1):
         r"""Initialize a :class:`L2BallIndicator` object.
 
         Args:
-            radius : Radius of :math:`\ell_2` ball. Default: 1.
+            radius: Radius of :math:`\ell_2` ball. Default: 1.
         """
         self.radius = radius
         super().__init__()
@@ -100,10 +104,13 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
     def prox(
         self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
     ) -> Union[JaxArray, BlockArray]:
-        r"""Evaluate proximal operator of indicator over :math:`\ell_2` ball:
+        r"""Proximal operator of indicator over :math:`\ell_2` ball.
 
-        .. math::
-            \mathrm{prox}(\mb{x}, \lambda) = \mathrm{radius} \frac{\mb{x}}{\norm{\mb{x}}_2}
+        Proximal operator of indicator over :math:`\ell_2` ball
+
+        .. math::
+            \mathrm{prox}(\mb{x}, \lambda) = \mathrm{radius}
+            \frac{\mb{x}}{\norm{\mb{x}}_2} \;.
         """
         return self.radius * x / norm(x)
diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py
index 44a204a5c..bfded542e 100644
--- a/scico/functional/_norm.py
+++ b/scico/functional/_norm.py
@@ -24,8 +24,10 @@
 
 class L0Norm(Functional):
-    r"""The :math:`\ell_0` 'norm'. Calculates the number of non-zero elements in an
-    array-like."""
+    r"""The :math:`\ell_0` 'norm'.
+
+    Counts the number of non-zero elements in an array-like.
+    """
 
     has_eval = True
     has_prox = True
@@ -39,30 +41,33 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
     def prox(
         x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
     ) -> Union[JaxArray, BlockArray]:
-        r"""Evaluate proximal operator of :math:`\ell_0` norm
+        r"""Evaluate proximal operator of :math:`\ell_0` norm.
+
+        Evaluate proximal operator of :math:`\ell_0` norm
 
        .. math::
            \mathrm{prox}(\mb{x}, \lambda) = \begin{cases}
-            \mb{x}, & \text{if } \abs{\mb{x}} \geq \lambda \\
-            0, & \text{else}
+            \mb{x}  & \text{if } \abs{\mb{x}} \geq \lambda \\
+            0  & \text{else} \;.
             \end{cases}
 
         Args:
-            x : Input array :math:`\mb{x}`
-            lam : Thresholding parameter :math:`\lambda`
-
+            x: Input array :math:`\mb{x}`.
+            lam: Thresholding parameter :math:`\lambda`.
         """
         return snp.where(snp.abs(x) >= lam, x, 0)
 
 
 class L1Norm(Functional):
-    r"""The :math:`\ell_1` norm.  Computes
+    r"""The :math:`\ell_1` norm.
+
+    Computes
 
    .. math::
-       \norm{\mb{x}}_1 = \sum_i \abs{x_i}^2
+       \norm{\mb{x}}_1 = \sum_i \abs{x_i} \;.
 
     """
 
     has_eval = True
@@ -74,21 +79,25 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
 
     @staticmethod
     def prox(x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs) -> JaxArray:
-        r"""Evaluate proximal operator of :math:`\ell_1` norm
+        r"""Evaluate proximal operator of :math:`\ell_1` norm.
+
+        Evaluate proximal operator of :math:`\ell_1` norm
 
        .. math::
            \mathrm{prox}(\mb{x}, \lambda)_i = \mathrm{sign}(\mb{x}_i)
-            (\abs{\mb{x}_i} - \lambda)_+ \;
+            (\abs{\mb{x}_i} - \lambda)_+ \;,
 
         where
 
        .. math::
            (x)_+ = \begin{cases}
-            x, & \text{if } x \geq 0 \\
-            0, & \text{else}
-            \end{cases} \;
-
+            x  & \text{if } x \geq 0 \\
+            0  & \text{else} \;.
+            \end{cases}
 
+        Args:
+            x: Input array :math:`\mb{x}`.
+            lam: Thresholding parameter :math:`\lambda`.
         """
         tmp = snp.abs(x) - lam
         tmp = 0.5 * (tmp + snp.abs(tmp))
@@ -102,10 +111,10 @@ def prox(x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs) -> JaxArray
 class SquaredL2Norm(Functional):
     r"""Squared :math:`\ell_2` norm.
 
-    .. math::
-       \norm{\mb{x}}^2_2 = \sum_i \abs{x_i}^2
-
+    Squared :math:`\ell_2` norm
 
+    .. math::
+       \norm{\mb{x}}^2_2 = \sum_i \abs{x_i}^2 \;.
     """
 
     has_eval = True
@@ -120,14 +129,17 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
     def prox(
         self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
     ) -> Union[JaxArray, BlockArray]:
-        r"""Evaluate proximal operator of squared :math:`\ell_2` norm:
+        r"""Evaluate proximal operator of squared :math:`\ell_2` norm.
+
+        Evaluate proximal operator of squared :math:`\ell_2` norm
 
        .. math::
-            \mathrm{prox}(\mb{x}, \lambda) = \frac{\mb{x}}{1 + 2 \lambda}
+            \mathrm{prox}(\mb{x}, \lambda) = \frac{\mb{x}}{1 +
+            2 \lambda} \;.
 
         Args:
-            x : Input array :math:`\mb{x}`
-            lam : Proximal parameter :math:`\lambda`
+            x: Input array :math:`\mb{x}`.
+            lam: Proximal parameter :math:`\lambda`.
         """
         return x / (1.0 + 2.0 * lam)
 
@@ -136,8 +148,7 @@ class L2Norm(Functional):
     r""":math:`\ell_2` norm.
 
    .. math::
-       \norm{\mb{x}}_2 = \sqrt{\sum_i \abs{x_i}^2}
-
+       \norm{\mb{x}}_2 = \sqrt{\sum_i \abs{x_i}^2} \;.
     """
 
     has_eval = True
@@ -150,22 +161,25 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
     def prox(
         self, x: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
     ) -> Union[JaxArray, BlockArray]:
-        r"""Evaluate proximal operator of :math:`\ell_2` norm:
+        r"""Evaluate proximal operator of :math:`\ell_2` norm.
+
+        Evaluate proximal operator of :math:`\ell_2` norm
 
        .. math::
-            \mathrm{prox}(\mb{x}, \lambda) = \mb{x} \right(1 - \frac{\lambda}{\norm{x}_2}\left)_+,
+            \mathrm{prox}(\mb{x}, \lambda) = \mb{x} \left(1 -
+            \frac{\lambda}{\norm{\mb{x}}_2} \right)_+ \;,
 
         where
 
        .. math::
            (x)_+ = \begin{cases}
-            x, & \text{if } x \geq 0 \\
-            0, & \text{else}
-            \end{cases} \;
+            x  & \text{if } x \geq 0 \\
+            0  & \text{else} \;.
+            \end{cases}
 
         Args:
-            x : Input array :math:`\mb{x}`
-            lam : Proximal parameter :math:`\lambda`
+            x: Input array :math:`\mb{x}`.
+            lam: Proximal parameter :math:`\lambda`.
         """
         norm_x = norm(x)
         if norm_x == 0:
@@ -180,12 +194,15 @@ class L21Norm(Functional):
 
     For a :math:`M \times N` matrix, :math:`\mb{A}`, by default,
 
    .. 
math:: - \norm{\mb{A}}_{2,1} = \sum_{n=1}^N \sqrt{\sum_{m=1}^M \abs{A_{m,n}}^2}. + \norm{\mb{A}}_{2,1} = \sum_{n=1}^N \sqrt{\sum_{m=1}^M + \abs{A_{m,n}}^2} \;. - The norm generalizes to more dimensions by first computing the :math:`\ell_2` norm along - a single (user-specified) dimension, followed by a sum over all remaining dimensions. + The norm generalizes to more dimensions by first computing the + :math:`\ell_2` norm along a single (user-specified) dimension, + followed by a sum over all remaining dimensions. - For `BlockArray` inputs, the :math:`\ell_2` norm follows the reduction rules described in :class:`BlockArray`. + For `BlockArray` inputs, the :math:`\ell_2` norm follows the + reduction rules described in :class:`BlockArray`. A typical use case is computing the isotropic total variation norm. """ @@ -213,16 +230,21 @@ def prox( In two dimensions, .. math:: - \mathrm{prox}(\mb{A}, \lambda)_{:, n} = \frac{\mb{A}_{:, n}}{\|\mb{A}_{:, n}\|_2} - (\|\mb{A}_{:, n}\|_2 - \lambda)_+ \; + \mathrm{prox}(\mb{A}, \lambda)_{:, n} = + \frac{\mb{A}_{:, n}}{\|\mb{A}_{:, n}\|_2} + (\|\mb{A}_{:, n}\|_2 - \lambda)_+ \;, where .. math:: (x)_+ = \begin{cases} - x, & \text{if } x \geq 0 \\ - 0, & \text{else}. - \end{cases} \; + x & \text{if } x \geq 0 \\ + 0 & \text{else} \;. + \end{cases} + + Args: + x: Input array :math:`\mb{x}`. + lam: Proximal parameter :math:`\lambda`. """ length = norm(x, axis=self.l2_axis, keepdims=True) diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index a3a5c3a9d..5bdc0e6dd 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -34,20 +34,20 @@ class CircularConvolve(LinearOperator): r"""A circular convolution linear operator. - This linear operator implements circular, n-dimensional convolution via - pointwise multiplication in the DFT domain. In its simplest form, it - implements a single convolution and can be represented by linear operator - :math:`H` such that + This linear operator implements circular, n-dimensional convolution + via pointwise multiplication in the DFT domain. In its simplest form, + it implements a single convolution and can be represented by linear + operator :math:`H` such that .. math:: H \mb{x} = \mb{h} \ast \mb{x} \;, where :math:`\mb{h}` is a user-defined filter. - More complex forms, corresponding to the case where either the input (as - represented by parameter `input_shape`) or filter (parameter `h`) have - additional axes that are not involved in the convolution are also - supported. These follow numpy broadcasting rules. For example: + More complex forms, corresponding to the case where either the input + (as represented by parameter `input_shape`) or filter (parameter `h`) + have additional axes that are not involved in the convolution are + also supported. These follow numpy broadcasting rules. For example: Additional axes in the input :math:`\mb{x}` and not in :math:`\mb{h}` corresponds to the operation @@ -57,15 +57,18 @@ class CircularConvolve(LinearOperator): 0 & H' & \ldots\\ \vdots & \vdots & \ddots \end{array} \right) - \left( \begin{array}{c} \mb{x}_0\\ \mb{x}_1\\ \vdots \end{array} \right) \;. + \left( \begin{array}{c} \mb{x}_0\\ \mb{x}_1\\ \vdots \end{array} + \right) \;. - Additional axes in :math:`\mb{h}` corresponds to multiple filters, which - will be denoted by :math:`\{\mb{h}_m\}`, with corresponding individual - linear operations being denoted by :math:`h_m \mb{x}_m = \mb{h}_m \ast - \mb{x}_m`. 
The full linear operator can then be represented as
+    Additional axes in :math:`\mb{h}` correspond to multiple filters,
+    which will be denoted by :math:`\{\mb{h}_m\}`, with corresponding
+    individual linear operations being denoted by :math:`h_m \mb{x}_m =
+    \mb{h}_m \ast \mb{x}_m`. The full linear operator can then be
+    represented as
 
    .. math::
-       H \mb{x} = \left( \begin{array}{c} H_0\\ H_1\\ \vdots \end{array} \right) \mb{x}
+       H \mb{x} = \left( \begin{array}{c} H_0\\ H_1\\ \vdots \end{array}
+       \right) \mb{x} \;.
 
     if the input is singleton, and as
 
@@ -74,7 +77,8 @@ class CircularConvolve(LinearOperator):
              0 & H_1 & \ldots\\
              \vdots & \vdots & \ddots
          \end{array} \right)
-        \left( \begin{array}{c} \mb{x}_0\\ \mb{x}_1\\ \vdots \end{array} \right) \;
+        \left( \begin{array}{c} \mb{x}_0\\ \mb{x}_1\\ \vdots \end{array}
+        \right)
 
     otherwise.
     """
@@ -94,10 +98,14 @@ def __init__(
         Args:
             h: Array of filters.
             input_shape: Shape of input array.
-            ndims: Number of (trailing) dimensions of the input and `h` involved in the convolution. Defaults to the number of dimensions in the input.
-            input_dtype: `dtype` for input argument. Defaults to `float32`.
-            h_is_dft: Flag indicating whether ``h`` is in the DFT domain
-            jit: If `True`, jit the evaluation, adjoint, and gram functions of the LinearOperator
+            ndims: Number of (trailing) dimensions of the input and `h`
+                involved in the convolution. Defaults to the number of
+                dimensions in the input.
+            input_dtype: `dtype` for input argument. Defaults to
+                `float32`.
+            h_is_dft: Flag indicating whether ``h`` is in the DFT domain.
+            jit: If `True`, jit the evaluation, adjoint, and gram
+                functions of the LinearOperator.
         """
 
         if ndims is None:
@@ -138,7 +146,8 @@ def __init__(
             output_shape = np.broadcast_shapes(self.h_dft.shape, input_shape)
         except ValueError:
             raise ValueError(
-                f"h shape after padding was {self.h_dft.shape}, needs to be compatible for broadcasting with {input_shape}."
+                f"h shape after padding was {self.h_dft.shape}, needs to be compatible "
+                f"for broadcasting with {input_shape}."
             )
 
         self.batch_axes = tuple(
@@ -233,15 +242,17 @@ def __truediv__(self, scalar):
     def from_operator(
         H: Operator, ndims: Optional[int] = None, center: Optional[Shape] = None, jit: bool = True
     ):
-        r"""Construct a CircularConvolve version of a given operator,
+        r"""Construct a CircularConvolve version of a given operator.
+
+        Construct a CircularConvolve version of a given operator,
         which is assumed to be linear and shift invariant (LSI).
 
         Args:
             H: Input operator.
             ndims: Number of trailing dims over which the H acts.
-            center: Location at which to place the Kronecker delta. For LSI inputs,
-                this will not matter. Defaults to the center of H.input_shape, i.e.,
-                (n_1 // 2, n_2 // 2, ...).
+            center: Location at which to place the Kronecker delta. For
+                LSI inputs, this will not matter. Defaults to the center
+                of H.input_shape, i.e., (n_1 // 2, n_2 // 2, ...).
             jit: If ``True``, jit the resulting `CircularConvolve`.
         """
 
@@ -270,16 +281,20 @@ def from_operator(
 
 
 def _gradient_filters(ndim: int, axes: Shape, shape: Shape, dtype: DType = snp.float32) -> JaxArray:
-    r"""Construct a set of filters for computing gradients in the frequency domain.
+    r"""Construct filters for computing gradients in the frequency domain.
+
+    Construct a set of filters for computing gradients in the frequency domain.
Args: - ndim: Total number of dimensions in array in which gradients are to be computed - axes: Axes on which gradients are to be computed - shape: Shape of axes on which gradients are to be computed - dtype: Data type of output arrays + ndim: Total number of dimensions in array in which gradients are + to be computed. + axes: Axes on which gradients are to be computed. + shape: Shape of axes on which gradients are to be computed. + dtype: Data type of output arrays. Returns: - An array of frequency domain gradient operators :math:`\hat{G}_i` + An array of frequency domain gradient operators + :math:`\hat{G}_i`. """ g = snp.zeros( [ diff --git a/scico/linop/_convolve.py b/scico/linop/_convolve.py index bbc360db6..55d594380 100644 --- a/scico/linop/_convolve.py +++ b/scico/linop/_convolve.py @@ -42,19 +42,23 @@ def __init__( jit: bool = True, **kwargs, ): - r""" - Wraps :func:`jax.scipy.signal.convolve` as a :class:`.LinearOperator`. + r"""Wrap :func:`jax.scipy.signal.convolve` as a LinearOperator. + + Wrap :func:`jax.scipy.signal.convolve` as a + :class:`.LinearOperator`. Args: - h: Convolutional filter. Must have same number of dimensions as - `len(input_shape)`. + h: Convolutional filter. Must have same number of dimensions + as `len(input_shape)`. input_shape: Shape of input array. - input_dtype: `dtype` for input argument. - Defaults to `float32`. If ``LinearOperator`` implements complex-valued operations, - this must be `complex64` for proper adjoint and gradient calculation. - mode: A string indicating the size of the output. One of "full", "valid", "same". - Defaults to "full". - jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. + input_dtype: `dtype` for input argument. Defaults to + `float32`. If ``LinearOperator`` implements + complex-valued operations, this must be `complex64` for + proper adjoint and gradient calculation. + mode: A string indicating the size of the output. One of + "full", "valid", "same". Defaults to "full". + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. For more details on `mode`, see :func:`jax.scipy.signal.convolve`. """ @@ -171,18 +175,19 @@ def __init__( r""" Args: - x: Convolutional filter. Must have same number of dimensions as - `len(input_shape)`. + x: Convolutional filter. Must have same number of dimensions + as `len(input_shape)`. input_shape: Shape of input array. - input_dtype: `dtype` for input argument. - Defaults to `float32`. If :class:`.LinearOperator` implements complex-valued operations, - this must be `complex64` for proper adjoint and gradient calculation. - mode: A string indicating the size of the output. One of "full", "valid", "same". - Defaults to "full". - jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. + input_dtype: `dtype` for input argument. Defaults to + `float32`. If :class:`.LinearOperator` implements + complex-valued operations, this must be `complex64` for + proper adjoint and gradient calculation. + mode: A string indicating the size of the output. One of + "full", "valid", "same". Defaults to "full". + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. For more details on `mode`, see :func:`jax.scipy.signal.convolve`. 
- """ self.x: JaxArray # : Fixed signal to convolve with diff --git a/scico/linop/_dft.py b/scico/linop/_dft.py index 2efacb1b9..a58404182 100644 --- a/scico/linop/_dft.py +++ b/scico/linop/_dft.py @@ -25,7 +25,7 @@ class DFT(LinearOperator): - r"""N-dimensional Discrete Fourier Transform""" + r"""N-dimensional Discrete Fourier Transform.""" def __init__( self, input_shape: Shape, output_shape: Optional[Shape] = None, jit: bool = True, **kwargs @@ -33,10 +33,12 @@ def __init__( r""" Args: input_shape: Shape of input array. - output_shape: Shape of transformed output. Along any axis, if the given - output_shape is larger than the input, the input is padded with zeros. - If None, the shape of the input is used. - jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. + output_shape: Shape of transformed output. Along any axis, + if the given output_shape is larger than the input, the + input is padded with zeros. If None, the shape of the + input is used. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. """ if output_shape is None: @@ -67,13 +69,16 @@ def _eval(self, x: JaxArray) -> JaxArray: return snp.fft.fftn(x, s=self.output_shape) def inv(self, z: JaxArray, truncate: bool = True) -> JaxArray: - """Compute the inverse of this LinearOperator applied to the point `z`. + """Compute the inverse of this LinearOperator. + + Compute the inverse of this LinearOperator applied to `z`. Args: - z: Array to take inverse DFT - truncate: If `True`, the inverse DFT is truncated to be `input_shape`. - This may be used when this DFT LinearOperator applies zero padding before - computing the DFT. + z: Array to take inverse DFT. + truncate: If `True`, the inverse DFT is truncated to be + `input_shape`. This may be used when this DFT + LinearOperator applies zero padding before computing the + DFT. """ y = snp.fft.ifftn(z) if truncate: diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index 1a6faf4a1..08d0c378a 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -29,9 +29,9 @@ class FiniteDifference(LinearOperatorStack): """Finite Difference operator. - Computes finite differences along the specified axes, returning the results in - a `DeviceArray` (whenever possible) or `BlockArray`. See :class:`LinearOperatorStack` for details - on how this choice is made. + Computes finite differences along the specified axes, returning the + results in a `DeviceArray` (whenever possible) or `BlockArray`. See + :class:`LinearOperatorStack` for details on how this choice is made. Example ------- @@ -58,16 +58,21 @@ def __init__( r""" Args: input_shape: Shape of input array. - input_dtype: `dtype` for input argument. - Defaults to `float32`. If `LinearOperator` implements complex-valued operations, - this must be `complex64` for proper adjoint and gradient calculation. - axes: Axis or axes over which to apply finite difference operator. If not specified, - or `None`, differences are evaluated along all axes. - append: Value to append to the input along each axis before taking differences. - Zero is a typical choice. If not `None`, `circular` must be ``False``. - circular: If ``True``, perform circular differences, i.e., include x[-1] - x[0]. - If ``True``, `append` must be `None`. - jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. + input_dtype: `dtype` for input argument. Defaults to + `float32`. 
If `LinearOperator` implements complex-valued
+                operations, this must be `complex64` for proper adjoint
+                and gradient calculation.
+            axes: Axis or axes over which to apply finite difference
+                operator. If not specified, or `None`, differences are
+                evaluated along all axes.
+            append: Value to append to the input along each axis before
+                taking differences. Zero is a typical choice. If not
+                `None`, `circular` must be ``False``.
+            circular: If ``True``, perform circular differences, i.e.,
+                include x[-1] - x[0]. If ``True``, `append` must be
+                `None`.
+            jit: If ``True``, jit the evaluation, adjoint, and gram
+                functions of the LinearOperator.
         """
 
         self.axes = parse_axes(axes, input_shape)
@@ -105,14 +110,17 @@ def __init__(
         Args:
             axis: Axis over which to apply finite difference operator.
             input_shape: Shape of input array.
-            input_dtype: `dtype` for input argument.
-                Defaults to `float32`. If `LinearOperator` implements complex-valued operations,
-                this must be `complex64` for proper adjoint and gradient calculation.
-            append: Value to append to the input along `axis` before taking differences.
-                Defaults to 0.
-            circular: If ``True``, perform circular differences, i.e., include x[-1] - x[0].
-                If ``True``, `append` must be `None`.
-            jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator.
+            input_dtype: `dtype` for input argument. Defaults to
+                `float32`. If `LinearOperator` implements complex-valued
+                operations, this must be `complex64` for proper adjoint
+                and gradient calculation.
+            append: Value to append to the input along `axis` before
+                taking differences. Defaults to 0.
+            circular: If ``True``, perform circular differences, i.e.,
+                include x[-1] - x[0]. If ``True``, `append` must be
+                `None`.
+            jit: If ``True``, jit the evaluation, adjoint, and gram
+                functions of the LinearOperator.
         """
 
         if not isinstance(axis, int):
@@ -120,7 +128,8 @@ def __init__(
 
         if axis >= len(input_shape):
             raise ValueError(
-                f"Invalid axis {axis} specified; `axis` must be less than `len(input_shape)`={len(input_shape)}"
+                f"Invalid axis {axis} specified; `axis` must be less than "
+                f"`len(input_shape)`={len(input_shape)}"
            )
 
         self.axis = axis
diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py
index e3858ae3b..d586a7248 100644
--- a/scico/linop/_matrix.py
+++ b/scico/linop/_matrix.py
@@ -71,8 +71,8 @@ class MatrixOperator(LinearOperator):
     def __init__(self, A: JaxArray):
         """
         Args:
-            A: Dense array. The action of the created LinearOperator will implement matrix
-                multiplication with `A`.
+            A: Dense array. The action of the created LinearOperator will
+                implement matrix multiplication with `A`.
         """
 
         self.A: JaxArray  #: Dense array implementing this matrix
@@ -106,8 +106,8 @@ def __call__(self, other):
             )
         else:
             raise ValueError(
-                """Cannot compute MatrixOperator-LinearOperator product, {other.output_shape} does not
-                match {self.input_shape}"""
+                "Cannot compute MatrixOperator-LinearOperator product, "
+                f"{other.output_shape} does not match {self.input_shape}"
             )
         else:
             return self._eval(other)
@@ -136,8 +136,8 @@ def __rsub__(self, other):
     def __neg__(self):
         return MatrixOperator(-self.A)
 
-    # Could write another wrapper for mul, truediv, and rtuediv, but
-    # there is no operator.__rtruediv__; have to write that case out manually anyway.
+    # Could write another wrapper for mul, truediv, and rtruediv, but there is
+    # no operator.__rtruediv__; have to write that case out manually anyway.
def __mul__(self, other): if np.isscalar(other): return MatrixOperator(other * self.A) @@ -191,33 +191,48 @@ def __getitem__(self, key): @property def T(self): - """Return a :class:`.MatrixOperator` corresponding to the transpose of this matrix""" + """Transpose of this :class:`.MatrixOperator`. + + Return a :class:`.MatrixOperator` corresponding to the transpose + of this matrix. + """ return MatrixOperator(self.A.T) @property def H(self): - """Return a :class:`.MatrixOperator` corresponding to the Hermitian (conjugate) transpose of this matrix""" + """Hermitian (conjugate) transpose of this :class:`.MatrixOperator`. + + Return a :class:`.MatrixOperator` corresponding to the Hermitian + (conjugate) transpose of this matrix. + """ return MatrixOperator(self.A.conj().T) def conj(self): - """Return a :class:`.MatrixOperator` with complex conjugated elements""" + """Complex conjugate of this :class:`.MatrixOperator`. + + Return a :class:`.MatrixOperator` with complex conjugated + elements. + """ return MatrixOperator(A=self.A.conj()) def adj(self, y): return self.A.conj().T @ y def to_array(self): - """Returns a :class:`numpy.ndarray` containing ``self.A``""" + """Return a :class:`numpy.ndarray` containing ``self.A``.""" return self.A.copy() @property def gram_op(self): - """Returns a new :class:`.LinearOperator` ``G`` such that ``G(x) = A.adj(A(x)))``""" + """Gram operator of this :class:`.MatrixOperator`. + + Return a new :class:`.LinearOperator` ``G`` such that + ``G(x) = A.adj(A(x))``.""" return MatrixOperator(A=self.A.conj().T @ self.A) def norm(self, ord=None, axis=None, keepdims=False): """Compute the norm of the dense matrix `self.A`. - Calls :func:`scico.numpy.norm` on the dense matrix `self.A`. + Call :func:`scico.numpy.norm` on the dense matrix `self.A`. """ return snp.linalg.norm(self.A, ord=ord, axis=axis, keepdims=keepdims) diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 0f75cee4c..4596d6092 100644 @@ -35,10 +35,10 @@ def __init__( r""" Args: ops: Operators to stack. - collapse: If `True` and the output would be a `BlockArray` with shape - ((m, n, ...), (m, n, ...), ...), the output is instead a `DeviceArray` - with shape (S, m, n, ...) where S is the length of `ops`. - Defaults to True. + collapse: If `True` and the output would be a `BlockArray` + with shape ((m, n, ...), (m, n, ...), ...), the output is + instead a `DeviceArray` with shape (S, m, n, ...) where S + is the length of `ops`. Defaults to True. jit: see `jit` in :class:`LinearOperator`. """ @@ -50,13 +50,15 @@ def __init__( input_shapes = [op.shape[1] for op in ops] if not all(input_shapes[0] == s for s in input_shapes): raise ValueError( - f"expected all `LinearOperator`s to have the same input shapes, but got {input_shapes}" + "expected all `LinearOperator`s to have the same input shapes, " + f"but got {input_shapes}" ) input_dtypes = [op.input_dtype for op in ops] if not all(input_dtypes[0] == s for s in input_dtypes): raise ValueError( - f"expected all `LinearOperator`s to have the same input dtype, but got {input_dtypes}." + "expected all `LinearOperator`s to have the same input dtype, " + f"but got {input_dtypes}."
) self.collapse = collapse @@ -100,12 +102,13 @@ def _adj(self, y: Union[JaxArray, BlockArray]) -> JaxArray: return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) def scale_ops(self, scalars: JaxArray): - """Return a copy of `self` with each operator scaled by the corresponding - entry in `scalars` + """Scale component linear operators. - Args: - scalars: List or array of scalars to use + Return a copy of `self` with each operator scaled by the + corresponding entry in `scalars`. + Args: + scalars: List or array of scalars to use. """ if len(scalars) != len(self.ops): raise ValueError("expected `scalars` to be the same length as `self.ops`") diff --git a/scico/linop/optics.py b/scico/linop/optics.py index c109bc84d..c04a23352 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -91,15 +91,15 @@ def __init__( ): r""" Args: - input_shape: Shape of input array. Can be a tuple of length + input_shape: Shape of input array. Can be a tuple of length 1 or 2. dx: Sampling interval at source plane. If a float and `len(input_shape)==2` the same sampling interval is applied to both dimensions. If `dx` is a tuple, must have same length as `input_shape`. - k0 : Illumination wavenumber; 2Ï€/wavelength - z : Propagation distance - pad_factor: Amount of padding to apply in DFT step + k0: Illumination wavenumber; 2Ï€/wavelength. + z: Propagation distance. + pad_factor: Amount of padding to apply in DFT step. """ ndim = len(input_shape) # 1 or 2 dimensions @@ -176,19 +176,19 @@ class AngularSpectrumPropagator(Propagator): .. math :: (A \mb{u})(x, y) = \iint_{-\infty}^{\infty} \mb{\hat{u}}(k_x, k_y) e^{j \sqrt{k_0^2 - k_x^2 - k_y^2} - \abs{z}} e^{j (x k_x + y k_y) } d k_x \ d k_y + \abs{z}} e^{j (x k_x + y k_y) } d k_x \ d k_y \;, where the :math:`\mb{\hat{u}}` is the Fourier transform of the field in the plane :math:`z=0`, given by .. math :: \mb{\hat{u}}(k_x, k_y) = \iint_{-\infty}^{\infty} - \mb{u}(x, y) e^{- j (x k_x + y k_y)} d k_x \ d k_y. + \mb{u}(x, y) e^{- j (x k_x + y k_y)} d k_x \ d k_y \;. The angular spectrum propagator can be written .. math :: - A\mb{u} = F^{-1} D F \mb{u} + A\mb{u} = F^{-1} D F \mb{u} \;, where :math:`F` is the Fourier transform with respect to :math:`(x, y)` and @@ -196,16 +196,14 @@ class AngularSpectrumPropagator(Propagator): .. math :: D = \mathrm{diag}\left(\exp \left\{ j \sqrt{k_0^2 - k_x^2 - k_y^2} \abs{z} - \right\} \right) - + \right\} \right) \;. The propagator is adequately sampled when :cite:`voelz-2009-digital` .. math :: (\Delta x)^2 \geq \frac{\pi}{k_0 N} \sqrt{ (\Delta x)^2 N^2 - + 4 z^2} - + + 4 z^2} \;. """ def __init__( @@ -223,9 +221,9 @@ def __init__( Args: input_shape: Shape of input array. Can be a tuple of length 2 or 3. - dx : Spatial sampling rate. - k0 : Illumination wavenumber. - z : Propagation distance. + dx: Spatial sampling rate. + k0: Illumination wavenumber. + z: Propagation distance. pad_factor: Amount of padding to apply in DFT step. jit: If ``True``, call :meth:`.jit()` on this LinearOperator to jit the forward, adjoint, and gram functions. Same as @@ -251,11 +249,12 @@ def __init__( def check_sampling(self): r"""Verify the angular spectrum kernel is not aliased. - Checks the condition for adequate sampling, :cite:`voelz-2009-digital` + Checks the condition for adequate sampling, + :cite:`voelz-2009-digital` .. math :: - (\Delta x)^2 \geq \frac{\pi}{k_0 N} \sqrt{ (\Delta x)^2 N^2 + 4 z^2} - + (\Delta x)^2 \geq \frac{\pi}{k_0 N} \sqrt{ (\Delta x)^2 N^2 + + 4 z^2} \;. 
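As a minimal illustration (hypothetical parameter values, assuming `scico` is installed), the condition above can be checked directly on a constructed propagator:

    import numpy as np
    from scico.linop.optics import AngularSpectrumPropagator

    prop = AngularSpectrumPropagator(
        input_shape=(256, 256), dx=1.0, k0=2 * np.pi / 0.5, z=100.0
    )
    ok = prop.check_sampling()  # True here: 1.0 >= ~0.32 for these values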
Returns: True if the angular spectrum kernel is adequately sampled, @@ -270,13 +269,13 @@ return False def pinv(self, y): - """Applies pseudoinverse of Angular Spectrum propagator.""" + """Apply pseudoinverse of Angular Spectrum propagator.""" diag_inv = safe_divide(1, self.D.diagonal) return self.F.inv(diag_inv * self.F(y)) class FresnelPropagator(Propagator): - r"""Fresnel Propagator + r"""Fresnel Propagator. Propagates a source field with coordinates :math:`(x, y, z_0)` to a destination plane at a distance :math:`z` with coordinates @@ -286,19 +285,19 @@ class FresnelPropagator(Propagator): .. math :: (A \mb{u})(x, y) = e^{j k_0 z} \iint_{-\infty}^{\infty} \mb{\hat{u}}(k_x, k_y) e^{-j \frac{z}{2 k_0} (k_x^2 + k_y^2) } - e^{j (x k_x + y k_y) } d k_x \ d k_y, + e^{j (x k_x + y k_y) } d k_x \ d k_y \;, where the :math:`\mb{\hat{u}}` is the Fourier transform of the field in the source plane, given by .. math :: \mb{\hat{u}}(k_x, k_y) = \iint_{-\infty}^{\infty} \mb{u}(x, y) - e^{- j (x k_x + y k_y)} d k_x \ d k_y. + e^{- j (x k_x + y k_y)} d k_x \ d k_y \;. The Fresnel propagator can be written .. math :: - A\mb{u} = F^{-1} D F \mb{u} + A\mb{u} = F^{-1} D F \mb{u} \;, where :math:`F` is the Fourier transform with respect to :math:`(x, y)` and @@ -337,10 +336,11 @@ def __init__( def check_sampling(self): r"""Verify the Fresnel propagation kernel is not aliased. - Checks the condition for adequate sampling, :cite:`voelz-2011-computational` + Checks the condition for adequate sampling, + :cite:`voelz-2011-computational` .. math :: - (\Delta x)^2 \geq \frac{2 \pi z }{k_0 N} + (\Delta x)^2 \geq \frac{2 \pi z }{k_0 N} \;. Returns: @@ -373,22 +373,24 @@ class FraunhoferPropagator(LinearOperator): \frac{e^{j k_0 z}}{j z} \mathrm{exp} \left\{ j \frac{k_0}{2 z} (x_D^2 + y_D^2) \right\}}_{\triangleq P(x_D, y_D)} \int \mb{u}(x_S, y_S) e^{-j \frac{k_0}{z} (x_D x_S + y_D y_S) - } dx_S \ dy_S. + } dx_S \ dy_S \;. Writing the Fourier transform of the field :math:`\mb{u}` as .. math :: - \hat{\mb{u}}(k_x, k_y) = \int e^{-j (k_x x + k_y y)} \mb{u}(x, y) dx \ dy, + \hat{\mb{u}}(k_x, k_y) = \int e^{-j (k_x x + k_y y)} + \mb{u}(x, y) dx \ dy \;, the action of this linear operator can be written .. math :: (A \mb{u})(x_D, y_D) = P(x_D, y_D) \ \hat{\mb{u}} - \left({\frac{k_0}{z} x_D, \frac{k_0}{z} y_D}\right). + \left({\frac{k_0}{z} x_D, \frac{k_0}{z} y_D}\right) \;. Ignoring multiplicative prefactors, the Fraunhofer propagated field is the Fourier transform of the source field, evaluated at - coordinates :math:`(k_x, k_y) = (\frac{k_0}{z} x_D, \frac{k_0}{z} y_D)`. + coordinates :math:`(k_x, k_y) = (\frac{k_0}{z} x_D, + \frac{k_0}{z} y_D)`. In general, the sampling intervals (and thus plane lengths) differ between source and destination planes. In particular, @@ -397,19 +399,20 @@ class FraunhoferPropagator(LinearOperator): .. math :: \begin{aligned} \Delta x_D &= \frac{2 \pi z}{k_0 L_S } \\ - L_D &= \frac{2 \pi z}{k_0 \Delta x_S } + L_D &= \frac{2 \pi z}{k_0 \Delta x_S } \;. \end{aligned} The sampling intervals and plane lengths coincide in the case of critical sampling: .. math :: - \Delta x_S = \sqrt{\frac{2 \pi z}{N k_0}} + \Delta x_S = \sqrt{\frac{2 \pi z}{N k_0}} \;. - The Fraunhofer phase :math:`P(x_D, y_D)` is adequately sampled when + The Fraunhofer phase :math:`P(x_D, y_D)` is adequately sampled + when .. math :: - \Delta x_S \geq \sqrt{\frac{2 \pi z}{N k_0}} + \Delta x_S \geq \sqrt{\frac{2 \pi z}{N k_0}} \;.
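A minimal usage sketch (hypothetical parameter values, assuming `scico` is installed; the dtype and values are illustrative only):

    import jax.numpy as jnp
    from scico.linop.optics import FraunhoferPropagator

    src = jnp.ones((128, 128), dtype=jnp.complex64)  # source plane field
    prop = FraunhoferPropagator(
        input_shape=(128, 128), dx=1.0, k0=2 * jnp.pi / 0.5, z=1.0e4
    )
    dst = prop(src)  # scaled Fourier transform of the source field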
""" def __init__( @@ -429,9 +432,9 @@ def __init__( `len(input_shape)==2` the same sampling interval is applied to both dimensions. If `dx` is a tuple, must have same length as `input_shape`. - k0 : Illumination wavenumber; 2Ï€/wavelength - z : Propagation distance - jit: If ``True``, jit the evaluation, adjoint, and gram + k0: Illumination wavenumber; 2Ï€/wavelength. + z: Propagation distance + jit: If ``True``, jit the evaluation, adjoint, and gram functions of this LinearOperator. Default: True. """ @@ -514,7 +517,7 @@ def check_sampling(self): :cite:`voelz-2011-computational` .. math :: - \Delta x^2 \geq \frac{2 \pi z }{k_0 N} + \Delta x^2 \geq \frac{2 \pi z }{k_0 N} \;. Returns: True if the Fresnel propagation kernel is adequately sampled, diff --git a/scico/linop/radon_astra.py b/scico/linop/radon_astra.py index 733ef3407..1b44efbed 100644 --- a/scico/linop/radon_astra.py +++ b/scico/linop/radon_astra.py @@ -38,7 +38,8 @@ class ParallelBeamProjector(LinearOperator): r"""Parallel beam Radon transform based on the ASTRA toolbox. Perform tomographic projection of an image at specified angles, - using the `ASTRA toolbox `_. + using the + `ASTRA toolbox `_. """ def __init__( @@ -67,9 +68,9 @@ def __init__( so not following these requirements may have unpredictable results. See `original ASTRA documentation `_. - detector_spacing: Spacing between detector elements - det_count: Number of detector elements - angles: Array of projection angles. + detector_spacing: Spacing between detector elements. + det_count: Number of detector elements. + angles: Array of projection angles. device: Specifies device for projection operation. One of ["auto", "gpu", "cpu"]. If "auto", a GPU is used if available. Otherwise, the CPU is used. @@ -152,7 +153,8 @@ def fbp(self, sino: JaxArray, filter_type: str = "Ram-Lak") -> JaxArray: Args: sino: Sinogram to reconstruct. - filter_type: Which filter to use, see `cfg.FilterType` in ``_. + filter_type: Which filter to use, see `cfg.FilterType` in + ``_. """ # Just use the CPU FBP alg for now; hitting memory issues with GPU one. diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 7a89292bd..c34f72af5 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -55,8 +55,8 @@ def __init__( input_shape: Shape of the input array. angles: Array of projection angles in radians, should be increasing. - num_channels: Number of pixels in the sinogram - is_masked: If True, the valid region of the image is + num_channels: Number of pixels in the sinogram. + is_masked: If True, the valid region of the image is determined by a mask defined as the circle inscribed within the image boundary. Otherwise, the whole image array is taken into account by projections. @@ -164,7 +164,6 @@ class SVMBIRWeightedSquaredL2Loss(WeightedSquaredL2Loss): quadrant, but the the loss, :math:`\alpha l(\mb{y}, A(\mb{x}))`, is unaffected by this setting and still evaluates to finite values when :math:`\mb{x}` is not in the non-negative quadrant. - """ def __init__( @@ -177,9 +176,9 @@ def __init__( r"""Initialize a :class:`SVMBIRWeightedSquaredL2Loss` object. Args: - y : Sinogram measurement. - A : Forward operator. - scale : Scaling parameter. + y: Sinogram measurement. + A: Forward operator. + scale: Scaling parameter. W: Weighting diagonal operator. Must be non-negative. If None, defaults to :class:`.Identity`. 
prox_kwargs: Dictionary of arguments passed to the diff --git a/scico/loss.py b/scico/loss.py index f2189493c..edaee1e8c 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -40,13 +40,15 @@ def wrapper(self, other): class Loss(functional.Functional): - r"""Generic Loss function. + r"""Generic loss function. - .. math:: - \alpha l(\mb{y}, A(\mb{x})) \; + Generic loss function - where :math:`\alpha` is the scaling parameter and :math:`l(\cdot)` is the loss functional. + .. math:: + \alpha l(\mb{y}, A(\mb{x})) \;, + where :math:`\alpha` is the scaling parameter and :math:`l(\cdot)` is + the loss functional. """ def __init__( @@ -58,9 +60,10 @@ def __init__( r"""Initialize a :class:`Loss` object. Args: - y : Measurement. - A : Forward operator. Defaults to None. If None, ``self.A`` is a :class:`.Identity`. - scale : Scaling parameter. Default: 0.5. + y: Measurement. + A: Forward operator. Defaults to None. If None, ``self.A`` is + a :class:`.Identity`. + scale: Scaling parameter. Default: 0.5. """ self.y = ensure_on_device(y) @@ -105,17 +108,18 @@ def set_scale(self, new_scale: float): class WeightedSquaredL2Loss(Loss): - r""" - Weighted squared :math:`\ell_2` loss. + r"""Weighted squared :math:`\ell_2` loss. + + Weighted squared :math:`\ell_2` loss .. math:: \alpha \norm{\mb{y} - A(\mb{x})}_W^2 = - \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right)\; + \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - + A(\mb{x})\right) \;, where :math:`\alpha` is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. If :math:`W` is None, reverts to the behavior of :class:`.SquaredL2Loss`. - """ def __init__( @@ -130,9 +134,9 @@ def __init__( r"""Initialize a :class:`WeightedSquaredL2Loss` object. Args: - y : Measurement. - A : Forward operator. If None, defaults to :class:`.Identity`. - scale : Scaling parameter. + y: Measurement. + A: Forward operator. If None, defaults to :class:`.Identity`. + scale: Scaling parameter. W: Weighting diagonal operator. Must be non-negative. If None, defaults to :class:`.Identity`. """ @@ -203,11 +207,11 @@ def prox( @property def hessian(self) -> linop.LinearOperator: - r"""If ``self.A`` is a :class:`scico.linop.LinearOperator`, returns a - :class:`scico.linop.LinearOperator` corresponding to the Hessian - :math:`2 \alpha \mathrm{A^H W A}`. + r"""Hessian of the loss function. - Otherwise not implemented. + If ``self.A`` is a :class:`scico.linop.LinearOperator`, returns a + :class:`scico.linop.LinearOperator` corresponding to the Hessian + :math:`2 \alpha \mathrm{A^H W A}`. Otherwise not implemented. """ A = self.A W = self.W @@ -220,19 +224,20 @@ def hessian(self) -> linop.LinearOperator: ) else: raise NotImplementedError( - f"Hessian is not implemented for {type(self)} when `A` is {type(A)}; must be LinearOperator" + f"Hessian is not implemented for {type(self)} when `A` is {type(A)}; " + "must be LinearOperator" ) class SquaredL2Loss(WeightedSquaredL2Loss): - r""" - Squared :math:`\ell_2` loss. + r"""Squared :math:`\ell_2` loss. + + Squared :math:`\ell_2` loss .. math:: - \alpha \norm{\mb{y} - A(\mb{x})}_2^2 \; + \alpha \norm{\mb{y} - A(\mb{x})}_2^2 \;, where :math:`\alpha` is the scaling parameter. - """ def __init__( @@ -245,19 +250,21 @@ def __init__( r"""Initialize a :class:`SquaredL2Loss` object. Args: - y : Measurement. - A : Forward operator. If None, defaults to :class:`.Identity`. - scale : Scaling parameter. + y: Measurement. + A: Forward operator. If None, defaults to :class:`.Identity`.
+ scale: Scaling parameter. """ super().__init__(y=y, A=A, scale=scale, W=None, prox_kwargs=prox_kwargs) class PoissonLoss(Loss): - r""" + r"""Poisson negative log likelihood loss. + Poisson negative log likelihood loss .. math:: - \alpha \left( \sum_i [A(x)]_i - y_i \log\left( [A(x)]_i \right) + \log(y_i!) \right) + \alpha \left( \sum_i [A(x)]_i - y_i \log\left( [A(x)]_i \right) + + \log(y_i!) \right) \;, where :math:`\alpha` is the scaling parameter. """ @@ -268,12 +275,13 @@ def __init__( A: Optional[Union[Callable, operator.Operator]] = None, scale: float = 0.5, ): - r"""Initialize a :class:`Loss` object. + r"""Initialize a :class:`PoissonLoss` object. Args: - y : Measurement. - A : Forward operator. Defaults to None. If None, ``self.A`` is a :class:`.Identity`. - scale : Scaling parameter. Default: 0.5. + y: Measurement. + A: Forward operator. Defaults to None. If None, ``self.A`` + is a :class:`.Identity`. + scale: Scaling parameter. Default: 0.5. """ y = ensure_on_device(y) super().__init__(y=y, A=A, scale=scale) diff --git a/scico/math.py b/scico/math.py index 6ef7033bf..46d7b327e 100644 --- a/scico/math.py +++ b/scico/math.py @@ -25,8 +25,8 @@ def safe_divide( """Return `x/y`, with 0 instead of NaN where `y` is 0. Args: - x: Numerator - y: Denominator + x: Numerator. + y: Denominator. Returns: `x / y` with 0 wherever `y == 0`. @@ -38,17 +38,19 @@ def safe_divide( def rel_res(ax: Union[BlockArray, JaxArray], b: Union[BlockArray, JaxArray]) -> float: r"""Relative residual of the solution to a linear equation. - The standard relative residual for the linear system :math:`A \mathbf{x} = \mathbf{b}` - is :math:`\|\mathbf{b} - A \mathbf{x}\|_2 / \|\mathbf{b}\|_2`. This function computes - a variant :math:`\|\mathbf{b} - A \mathbf{x}\|_2 / \max(\|A\mathbf{x}\|_2, - \|\mathbf{b}\|_2)` that is robust to the case :math:`\mathbf{b} = 0`. + The standard relative residual for the linear system + :math:`A \mathbf{x} = \mathbf{b}` is :math:`\|\mathbf{b} - + A \mathbf{x}\|_2 / \|\mathbf{b}\|_2`. This function computes a + variant :math:`\|\mathbf{b} - A \mathbf{x}\|_2 / + \max(\|A\mathbf{x}\|_2, \|\mathbf{b}\|_2)` that is robust to the case + :math:`\mathbf{b} = 0`. Args: ax: Linear component :math:`A \mathbf{x}` of equation. b: Constant component :math:`\mathbf{b}` of equation. Returns: - x: Relative residual value. + Relative residual value. """ nrm = max(snp.linalg.norm(ax.ravel()), snp.linalg.norm(b.ravel())) @@ -62,10 +64,11 @@ def is_real_dtype(dtype: DType) -> bool: """Determine whether a dtype is real. Args: - dtype: A numpy or scico.numpy dtype (e.g. np.float32, snp.complex64) + dtype: A numpy or scico.numpy dtype (e.g. np.float32, + snp.complex64). Returns: - False if the dtype is complex, otherwise True + False if the dtype is complex, otherwise True. """ return snp.dtype(dtype).kind != "c" @@ -74,10 +77,11 @@ def is_complex_dtype(dtype: DType) -> bool: """Determine whether a dtype is complex. Args: - dtype: A numpy or scico.numpy dtype (e.g. np.float32, snp.complex64) + dtype: A numpy or scico.numpy dtype (e.g. np.float32, + snp.complex64). Returns: - True if the dtype is complex, otherwise False + True if the dtype is complex, otherwise False. """ return snp.dtype(dtype).kind == "c" @@ -90,7 +94,8 @@ def real_dtype(dtype: DType) -> DType: `np.float32`. Args: - dtype: A complex numpy or scico.numpy dtype, e.g. np.complex64, np.complex128 + dtype: A complex numpy or scico.numpy dtype (e.g. np.complex64, + np.complex128). 
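These dtype utilities compose as in this short sketch (assuming `scico` is installed; the mappings shown follow the docstrings in this module):

    import numpy as np
    from scico.math import complex_dtype, is_complex_dtype, is_real_dtype, real_dtype

    assert is_real_dtype(np.float32) and not is_complex_dtype(np.float32)
    assert real_dtype(np.complex64) == np.float32     # complex -> real
    assert complex_dtype(np.float32) == np.complex64  # real -> complex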
Returns: The real dtype corresponding to the input dtype @@ -107,10 +112,11 @@ def complex_dtype(dtype: DType) -> DType: `np.complex64`. Args: - dtype: A real numpy or scico.numpy dtype, e.g. np.float32, np.float64 + dtype: A real numpy or scico.numpy dtype (e.g. np.float32, + np.float64). Returns: - The complex dtype corresponding to the input dtype + The complex dtype corresponding to the input dtype. """ return (snp.zeros(1, dtype) + 1j).dtype diff --git a/scico/metric.py b/scico/metric.py index b984c5645..a9f89cd29 100644 --- a/scico/metric.py +++ b/scico/metric.py @@ -21,46 +21,42 @@ def mae(reference: Union[JaxArray, BlockArray], comparison: Union[JaxArray, BlockArray]) -> float: - """ - Compute Mean Absolute Error (MAE) between two images. + """Compute Mean Absolute Error (MAE) between two images. Args: - reference: Reference image - comparison: Comparison image + reference: Reference image. + comparison: Comparison image. Returns: - MAE between `reference` and `comparison` + MAE between `reference` and `comparison`. """ return snp.mean(snp.abs(reference - comparison).ravel()) def mse(reference: Union[JaxArray, BlockArray], comparison: Union[JaxArray, BlockArray]) -> float: - """ - Compute Mean Squared Error (MSE) between two images. + """Compute Mean Squared Error (MSE) between two images. Args: - reference : Reference image - comparison : Comparison image + reference: Reference image. + comparison: Comparison image. Returns: - MSE between `reference` and `comparison` + MSE between `reference` and `comparison`. """ return snp.mean(snp.abs(reference - comparison).ravel() ** 2) def snr(reference: Union[JaxArray, BlockArray], comparison: Union[JaxArray, BlockArray]) -> float: - - """ - Compute Signal to Noise Ratio (SNR) of two images. + """Compute Signal to Noise Ratio (SNR) of two images. Args: - reference: Reference image - comparison: Comparison image + reference: Reference image. + comparison: Comparison image. Returns: - SNR of `comparison` with respect to `reference` + SNR of `comparison` with respect to `reference`. """ dv = snp.var(reference) @@ -82,14 +78,14 @@ def psnr( (i.e. :math:`2^b-1` for a :math:`b` bit representation). Args: - reference: Reference image - comparison: Comparison image - signal_range: Signal range, either the - value to use (e.g. 255 for 8 bit samples) or None, in which case - the actual range of the reference signal is used + reference: Reference image. + comparison: Comparison image. + signal_range: Signal range, either the value to use (e.g. 255 + for 8 bit samples) or None, in which case the actual range + of the reference signal is used. Returns: - PSNR of `comparison` with respect to `reference` + PSNR of `comparison` with respect to `reference`. """ if signal_range is None: @@ -104,17 +100,18 @@ def isnr( degraded: Union[JaxArray, BlockArray], restored: Union[JaxArray, BlockArray], ) -> float: - """ + """Compute Improvement Signal to Noise Ratio (ISNR). + Compute Improvement Signal to Noise Ratio (ISNR) for reference, degraded, and restored images. Args: - reference: Reference image - degraded: Degraded image - restored: Restored image + reference: Reference image. + degraded: Degraded image. + restored: Restored image. Returns: - ISNR of `restored` with respect to `reference` and `degraded` + ISNR of `restored` with respect to `reference` and `degraded`.
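A minimal sketch of these metrics (assuming `scico` is installed; the array values are arbitrary, and the 20 dB figure follows from the standard ISNR definition):

    import numpy as np
    from scico.metric import isnr, mse

    ref = np.zeros((8, 8), dtype=np.float32)
    deg = ref + 0.1             # degraded image; mse(ref, deg) = 0.01
    res = ref + 0.01            # restored image; mse(ref, res) = 0.0001
    gain = isnr(ref, deg, res)  # 10 log10(0.01 / 0.0001) = 20 dB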
""" msedeg = mse(reference, degraded) @@ -125,16 +122,17 @@ def isnr( def bsnr(blurry: Union[JaxArray, BlockArray], noisy: Union[JaxArray, BlockArray]) -> float: - """ + """Compute Blurred Signal to Noise Ratio (BSNR). + Compute Blurred Signal to Noise Ratio (BSNR) for a blurred and noisy image. Args: - blurry: Blurred noise free image - noisy: Blurred image with additive noise + blurry: Blurred noise free image. + noisy: Blurred image with additive noise. Returns: - BSNR of `noisy` with respect to `blurry` and `degraded` + BSNR of `noisy` with respect to `blurry` and `degraded`. """ blrvar = snp.var(blurry) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index d9361ecd9..713f55ec2 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -7,11 +7,18 @@ """Wrapped versions of :mod:`jax.numpy` functions. -This modules consists of functions from :mod:`jax.numpy`. Some of these functions are wrapped to support compatibility with :class:`scico.blockarray.BlockArray` and are documented here. The remaining functions are imported directly from :mod:`jax.numpy`. While they can be imported from the :mod:`scico.numpy` namespace, they are not documented here; please consult the documentation for the source module :mod:`jax.numpy`. +This modules consists of functions from :mod:`jax.numpy`. Some of these +functions are wrapped to support compatibility with +:class:`scico.blockarray.BlockArray` and are documented here. The +remaining functions are imported directly from :mod:`jax.numpy`. While +they can be imported from the :mod:`scico.numpy` namespace, they are not +documented here; please consult the documentation for the source module +:mod:`jax.numpy`. .. todo:: - Provide detailed discussion of purpose of wrapping functions, and how to determine which functions are wrapped and which are not. + Provide detailed discussion of purpose of wrapping functions, and how + to determine which functions are wrapped and which are not. """ import sys diff --git a/scico/numpy/_create.py b/scico/numpy/_create.py index 1b9bb2a55..56f9d92ba 100644 --- a/scico/numpy/_create.py +++ b/scico/numpy/_create.py @@ -27,8 +27,8 @@ def zeros( If `shape` is a list of tuples, returns a BlockArray of zeros. Args: - shape : Shape of the new array. - dtype: Desired data-type of the array. Default is `np.float32`. + shape: Shape of the new array. + dtype: Desired data-type of the array. Default is `np.float32`. """ if is_nested(shape): return BlockArray.zeros(shape, dtype=dtype) @@ -42,8 +42,8 @@ def ones(shape: Union[Shape, BlockShape], dtype: DType = np.float32) -> Union[Ja If `shape` is a list of tuples, returns a BlockArray of ones. Args: - shape : Shape of the new array. - dtype: Desired data-type of the array. Default is `np.float32`. + shape: Shape of the new array. + dtype: Desired data-type of the array. Default is `np.float32`. """ if is_nested(shape): return BlockArray.ones(shape, dtype=dtype) @@ -59,8 +59,8 @@ def empty( If `shape` is a list of tuples, returns a BlockArray of zeros. Args: - shape : Shape of the new array. - dtype: Desired data-type of the array. Default is `np.float32`. + shape: Shape of the new array. + dtype: Desired data-type of the array. Default is `np.float32`. """ if is_nested(shape): return BlockArray.empty(shape, dtype=dtype) @@ -75,13 +75,14 @@ def full( ) -> Union[JaxArray, BlockArray]: """Return a new array of given shape and type, filled with `fill_value`. - If `shape` is a list of tuples, returns a BlockArray filled with `fill_value`. 
+ If `shape` is a list of tuples, returns a BlockArray filled with + `fill_value`. Args: - shape : Shape of the new array. + shape: Shape of the new array. fill_value : Fill value. - dtype: Desired data-type of the array. The default, None, - means `np.array(fill_value).dtype` + dtype: Desired data-type of the array. The default, None, + means `np.array(fill_value).dtype`. """ if dtype is None: dtype = jax.dtypes.canonicalize_dtype(type(fill_value)) @@ -94,13 +95,14 @@ def full( def zeros_like(x: Union[JaxArray, BlockArray], dtype=None): """Return an array of zeros with same shape and type as a given array. - If input is a BlockArray, returns a BlockArray of zeros with same shape and type - as a given array. - + If input is a BlockArray, returns a BlockArray of zeros with same + shape and type as a given array. Args: - x (array like): The shape and dtype of `x` define these attributes on the returned array. - dtype (data-type, optional): Overrides the data type of the result. + x (array like): The shape and dtype of `x` define these + attributes on the returned array. + dtype (data-type, optional): Overrides the data type of the + result. """ if dtype is None: dtype = jax.dtypes.canonicalize_dtype(x.dtype) @@ -114,15 +116,17 @@ def zeros_like(x: Union[JaxArray, BlockArray], dtype=None): def empty_like(x: Union[JaxArray, BlockArray], dtype: DType = None): """Return an array of zeros with same shape and type as a given array. - If input is a BlockArray, returns a BlockArray of zeros with same shape and type - as a given array. - - Note: like :func:`jax.numpy.empty_like`, this does not return an uninitalized array. + If input is a BlockArray, returns a BlockArray of zeros with same + shape and type as a given array. + Note: like :func:`jax.numpy.empty_like`, this does not return an + uninitialized array. Args: - x (array like): The shape and dtype of `x` define these attributes on the returned array. - dtype (data-type, optional): Overrides the data type of the result. + x (array like): The shape and dtype of `x` define these + attributes on the returned array. + dtype (data-type, optional): Overrides the data type of the + result. """ if dtype is None: dtype = jax.dtypes.canonicalize_dtype(x.dtype) @@ -136,13 +140,14 @@ def empty_like(x: Union[JaxArray, BlockArray], dtype: DType = None): def ones_like(x: Union[JaxArray, BlockArray], dtype: DType = None): """Return an array of ones with same shape and type as a given array. - If input is a BlockArray, returns a BlockArray of ones with same shape and type - as a given array. - + If input is a BlockArray, returns a BlockArray of ones with same + shape and type as a given array. Args: - x (array like): The shape and dtype of `x` define these attributes on the returned array. - dtype (data-type, optional): Overrides the data type of the result. + x (array like): The shape and dtype of `x` define these + attributes on the returned array. + dtype (data-type, optional): Overrides the data type of the + result. """ if dtype is None: dtype = jax.dtypes.canonicalize_dtype(x.dtype) @@ -156,16 +161,18 @@ def ones_like(x: Union[JaxArray, BlockArray], dtype: DType = None): def full_like( x: Union[JaxArray, BlockArray], fill_value: Union[float, complex], dtype: DType = None ): - """Return an array of with same shape and type as a given array, filled with `fill_value`. - - If input is a BlockArray, returns a BlockArray of `fill_value` with same shape and type - as a given array. + """Return an array filled with `fill_value`.
+ Return an array with same shape and type as a given array, filled + with `fill_value`. If input is a BlockArray, returns a BlockArray of + `fill_value` with same shape and type as a given array. Args: - x (array like): The shape and dtype of `x` define these attributes on the returned array. - fill_value (scalar): Fill value. - dtype (data-type, optional): Overrides the data type of the result. + x (array like): The shape and dtype of `x` define these + attributes on the returned array. + fill_value (scalar): Fill value. + dtype (data-type, optional): Overrides the data type of the + result. """ if dtype is None: dtype = jax.dtypes.canonicalize_dtype(x.dtype) diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index 1dcd93afd..b7c747638 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -5,8 +5,7 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Tools to construct wrapped versions of :mod:`jax.numpy` functions. -""" +"""Tools to construct wrapped versions of :mod:`jax.numpy` functions.""" import re import types @@ -67,14 +66,17 @@ def _attach_wrapped_func(funclist, wrapper, module_name, fix_mod_name=False): def _get_module_functions(module): """Finds functions in module. - This function is a slightly modified version of :func:`jax._src.util.get_module_functions`. - Unlike the JAX version, this version will also return any - :class:`jaxlib.xla_extension.CompiledFunction`s that exist in the module. + This function is a slightly modified version of + :func:`jax._src.util.get_module_functions`. Unlike the JAX version, + this version will also return any + :class:`jaxlib.xla_extension.CompiledFunction`s that exist in the + module. Args: module: A Python module. Returns: - module_fns: A dict of names mapped to functions, builtins or ufuncs in `module`. + module_fns: A dict of names mapped to functions, builtins or + ufuncs in `module`. """ module_fns = {} for key in dir(module): diff --git a/scico/numpy/fft.py b/scico/numpy/fft.py index 46dc094f1..0577f6acd 100644 --- a/scico/numpy/fft.py +++ b/scico/numpy/fft.py @@ -7,7 +7,13 @@ """Construct wrapped versions of :mod:`jax.numpy.fft` functions. -This modules consists of functions from :mod:`jax.numpy.fft`. Some of these functions are wrapped to support compatibility with :class:`scico.blockarray.BlockArray` and are documented here. The remaining functions are imported directly from :mod:`jax.numpy.fft`. While they can be imported from the :mod:`scico.numpy.fft` namespace, they are not documented here; please consult the documentation for the source module :mod:`jax.numpy.fft`. +This module consists of functions from :mod:`jax.numpy.fft`. Some of +these functions are wrapped to support compatibility with +:class:`scico.blockarray.BlockArray` and are documented here. +The remaining functions are imported directly from :mod:`jax.numpy.fft`. +While they can be imported from the :mod:`scico.numpy.fft` namespace, +they are not documented here; please consult the documentation for the +source module :mod:`jax.numpy.fft`. """ import sys diff --git a/scico/numpy/linalg.py b/scico/numpy/linalg.py index 0d95f8259..e47191904 100644 --- a/scico/numpy/linalg.py +++ b/scico/numpy/linalg.py @@ -7,7 +7,13 @@ """Construct wrapped versions of :mod:`jax.numpy.linalg` functions. -This modules consists of functions from :mod:`jax.numpy.linalg`. Some of these functions are wrapped to support compatibility with :class:`scico.blockarray.BlockArray` and are documented here.
The remaining functions are imported directly from :mod:`jax.numpy.linalg`. While they can be imported from the :mod:`scico.numpy.linalg` namespace, they are not documented here; please consult the documentation for the source module :mod:`jax.numpy.linalg`. +This module consists of functions from :mod:`jax.numpy.linalg`. Some of +these functions are wrapped to support compatibility with +:class:`scico.blockarray.BlockArray` and are documented here. The +remaining functions are imported directly from :mod:`jax.numpy.linalg`. +While they can be imported from the :mod:`scico.numpy.linalg` namespace, +they are not documented here; please consult the documentation for the +source module :mod:`jax.numpy.linalg`. """ @@ -31,7 +37,10 @@ def _extract_if_matrix(x): def _matrixop_linalg_wrapper(func): - """Wraps :mod:`jax.numpy.linalg` functions for joint operation on `MatrixOperator` and `DeviceArray`""" + """Wrap :mod:`jax.numpy.linalg` functions. + + Wrap :mod:`jax.numpy.linalg` functions for joint operation on + `MatrixOperator` and `DeviceArray`.""" @wraps(func) def wrapper(*args, **kwargs): @@ -72,8 +81,11 @@ def wrapper(*args, **kwargs): # multidot is somewhat unique def multi_dot(arrays, *, precision=None): - """Computes the dot product of two or more arrays. - Wrapped to work with `MatrixOperator`s.""" + """Compute the dot product of two or more arrays. + + Compute the dot product of two or more arrays. + Wrapped to work with `MatrixOperator`s. + """ arrays_ = [_extract_if_matrix(_) for _ in arrays] return jla.multi_dot(arrays_, precision=precision) diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index 6a8570fd8..2d6d44bac 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -24,11 +24,11 @@ class BiConvolve(Operator): """BiConvolution operator. - A BiConvolve operator accepts a :class:`.BlockArray` input with two blocks - of equal ndims, and convolves the first block with the second. + A BiConvolve operator accepts a :class:`.BlockArray` input with two + blocks of equal ndims, and convolves the first block with the second. If `A` is a BiConvolve operator, then - A(BlockArray.array([x, h])) = jax.scipy.signal.convolve(x, h) + `A(BlockArray.array([x, h]))` equals `jax.scipy.signal.convolve(x, h)`. """ def __init__( self, input_shape: BlockShape, input_dtype: DType = np.float32, mode: str = "full", jit: bool = True, ): r""" Args: - input_shape: Shape of input BlockArray. Must correspond to a BlockArray - with two blocks of equal ndims. - input_dtype: `dtype` for input argument. Defaults to `float32`. - mode: A string indicating the size of the output. One of "full", "valid", "same". - Defaults to "full". - jit: If ``True``, jit the evaluation of this Operator. + input_shape: Shape of input BlockArray. Must correspond to a + BlockArray with two blocks of equal ndims. + input_dtype: `dtype` for input argument. Defaults to + `float32`. + mode: A string indicating the size of the output. One of + "full", "valid", "same". Defaults to "full". + jit: If ``True``, jit the evaluation of this Operator. For more details on `mode`, see :func:`jax.scipy.signal.convolve`. """ @@ -72,14 +73,17 @@ def _eval(self, x: BlockArray) -> JaxArray: return convolve(x[0], x[1], mode=self.mode) def freeze(self, argnum: int, val: JaxArray) -> LinearOperator: - """Returns a new :class:`.LinearOperator` with block argument `argnum` fixed to value `val`. + """Freeze the `argnum` parameter. + + Return a new :class:`.LinearOperator` with block argument + `argnum` fixed to value `val`. If ``argnum == 0``, a :class:`.ConvolveByX` object is returned.
If ``argnum == 1``, a :class:`.Convolve` object is returned. Args: - argnum: Index of block to freeze. Must be 0 or 1. - val: Value to fix the `argnum`-th input to. + argnum: Index of block to freeze. Must be 0 or 1. + val: Value to fix the `argnum`-th input to. """ if argnum == 0: diff --git a/scico/pgm.py b/scico/pgm.py index 741212ed6..d39299679 100644 --- a/scico/pgm.py +++ b/scico/pgm.py @@ -152,11 +152,11 @@ class AdaptiveBBStepSize(PGMStepSize): .. math:: - \alpha = \left\{ \begin{matrix} \alpha^{\mathrm{BB2}} \;, & + \alpha = \left\{ \begin{matrix} \alpha^{\mathrm{BB2}} & \mathrm{~if~} \alpha^{\mathrm{BB2}} / \alpha^{\mathrm{BB1}} < \kappa \; \\ - \alpha^{\mathrm{BB1}} \;, & \mathrm{~otherwise} \end{matrix} - \right . \;\;, + \alpha^{\mathrm{BB1}} & \mathrm{~otherwise} \end{matrix} + \right . \;, with :math:`\kappa \in (0, 1)`. @@ -244,7 +244,7 @@ class LineSearchStepSize(PGMStepSize): .. math:: \hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H (\mb{x} - \mb{y}) + \frac{L}{2} \left\| \mb{x} - \mb{y} - \right\|_2^2 \;\;, + \right\|_2^2 \;, with :math:`\mb{x}` the potential new update and :math:`\mb{y}` the current solution or current extrapolation (if accelerated PGM). @@ -306,7 +306,8 @@ class RobustLineSearchStepSize(LineSearchStepSize): .. math:: \hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H - (\mb{x} - \mb{y}) + \frac{L}{2} \left\| \mb{x} - \mb{y} \right\|_2^2 \;\;, + (\mb{x} - \mb{y}) + \frac{L}{2} \left\| \mb{x} - \mb{y} + \right\|_2^2 \;, with :math:`\mb{x}` the potential new update and :math:`\mb{y}` the auxiliary extrapolation state. @@ -373,7 +374,8 @@ class PGM: Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`. - The function :math:`f` must be smooth and :math:`g` must have a defined prox. + The function :math:`f` must be smooth and :math:`g` must have a + defined prox. Uses helper :class:`StepSize` to provide an estimate of the Lipschitz constant :math:`L` of :math:`f`. The step size :math:`\alpha` is the @@ -398,15 +400,21 @@ def __init__( g: Instance of Functional with defined prox method L0: Initial estimate of Lipschitz constant of f x0: Starting point for :math:`\mb{x}` - step_size: helper :class:`StepSize` to estimate the Lipschitz constant of f - maxiter: Maximum number of PGM iterations to perform. Default: 100. - verbose: Flag indicating whether iteration statistics should be displayed. - itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` is a dict suitable - for passing to the `fields` argument of the :class:`.diagnostics.IterationStats` - initializer, and `insertfunc` is a function with two parameters, an integer - and a PGM object, responsible for constructing a tuple ready for insertion into - the :class:`.diagnostics.IterationStats` object. If None, default values are - used for the tuple components. + step_size: Helper :class:`StepSize` to estimate the Lipschitz + constant of f. + maxiter: Maximum number of PGM iterations to perform. + Default: 100. + verbose: Flag indicating whether iteration statistics should + be displayed. + itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` + is a dict suitable + for passing to the `fields` argument of the + :class:`.diagnostics.IterationStats` initializer, and + `insertfunc` is a function with two parameters, an + integer and a PGM object, responsible for constructing a + tuple ready for insertion into the + :class:`.diagnostics.IterationStats` object. If None, + default values are used for the tuple components.
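A minimal end-to-end sketch (a hypothetical problem setup, assuming `scico` is installed; `L1Norm` stands in for any functional with a defined prox, and the values are illustrative only):

    import scico.numpy as snp
    from scico.functional import L1Norm
    from scico.loss import SquaredL2Loss
    from scico.pgm import PGM

    y = snp.ones((16,))                 # measurement (illustrative)
    f = SquaredL2Loss(y=y)              # smooth term; A defaults to Identity
    g = L1Norm()                        # non-smooth term with a prox
    solver = PGM(f=f, g=g, L0=1.0, x0=snp.zeros(y.shape), maxiter=50)
    x = solver.solve()                  # returns the computed solution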
""" if f.is_smooth is not True: @@ -470,9 +478,10 @@ def objective(self, x) -> float: def f_quad_approx(self, x, y, L) -> float: r"""Evaluate the quadratic approximation to function :math:`f`. - Evaluate the quadratic approximation to function :math:`f`, corresponding to - :math:`\hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H (\mb{x} - \mb{y}) - + \frac{L}{2} \left\|\mb{x} - \mb{y}\right\|_2^2`. + Evaluate the quadratic approximation to function :math:`f`, + corresponding to :math:`\hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + + \nabla f(\mb{y})^H (\mb{x} - \mb{y}) + \frac{L}{2} \left\|\mb{x} + - \mb{y}\right\|_2^2`. """ diff_xy = x - y return ( @@ -482,7 +491,11 @@ def f_quad_approx(self, x, y, L) -> float: ) def norm_residual(self) -> float: - r"""Return the fixed point residual (see Sec. 4.3 of :cite:`liu-2018-first`).""" + r"""Return the fixed point residual. + + Return the fixed point residual (see Sec. 4.3 of + :cite:`liu-2018-first`) + """ return self.fixed_point_residual def step(self): @@ -502,8 +515,9 @@ def solve( Run the PGM algorithm for a total of ``self.maxiter`` iterations. Args: - callback: An optional callback function, taking an a single argument of type - :class:`PGM`, that is called at the end of every iteration. + callback: An optional callback function, taking an a single + argument of type :class:`PGM`, that is called at the end + of every iteration. Returns: Computed solution. @@ -526,9 +540,9 @@ class AcceleratedPGM(PGM): Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`. - The function :math:`f` must be smooth and :math:`g` must have a defined prox. - - The accelerated form of PGM is also known as FISTA :cite:`beck-2009-fast`. + The function :math:`f` must be smooth and :math:`g` must have a + defined prox. The accelerated form of PGM is also known as FISTA + :cite:`beck-2009-fast`. For documentation on inherited attributes, see :class:`.PGM`. """ @@ -551,15 +565,20 @@ def __init__( g: Instance of Functional with defined prox method L0: Initial estimate of Lipschitz constant of f x0: Starting point for :math:`\mb{x}` - step_size: helper :class:`StepSize` to estimate the Lipschitz constant of f - maxiter: Maximum number of PGM iterations to perform. Default: 100. - verbose: Flag indicating whether iteration statistics should be displayed. - itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` is a dict suitable - for passing to the `fields` argument of the :class:`.diagnostics.IterationStats` - initializer, and `insertfunc` is a function with two parameters, an integer - and a PGM object, responsible for constructing a tuple ready for insertion into - the :class:`.diagnostics.IterationStats` object. If None, default values are - used for the tuple components. + step_size: helper :class:`StepSize` to estimate the Lipschitz + constant of f + maxiter: Maximum number of PGM iterations to perform. + Default: 100. + verbose: Flag indicating whether iteration statistics should + be displayed. + itstat: A tuple (`fieldspec`, `insertfunc`), where `fieldspec` + is a dict suitable for passing to the `fields` argument + of the :class:`.diagnostics.IterationStats` initializer, + and `insertfunc` is a function with two parameters, an + integer and a PGM object, responsible for constructing a + tuple ready for insertion into the + :class:`.diagnostics.IterationStats` object. If None, + default values are used for the tuple components. 
""" x0 = ensure_on_device(x0) super().__init__( diff --git a/scico/plot.py b/scico/plot.py index e57931be6..8c50d83cf 100644 --- a/scico/plot.py +++ b/scico/plot.py @@ -7,7 +7,8 @@ """Plotting/visualization functions. -Optional alternative high-level interface to selected :mod:`matplotlib` plotting functions. +Optional alternative high-level interface to selected :mod:`matplotlib` +plotting functions. """ # This module is copied from https://github.com/bwohlberg/sporco @@ -47,23 +48,25 @@ def _attach_keypress(fig, scaling=1.1): - """ - Attach a key press event handler that configures keys for closing a figure - and changing the figure size. Keys 'e' and 'c' respectively expand and - contract the figure, and key 'q' closes it. + """Attach a key press event handler. + + Attach a key press event handler that configures keys for closing a + figure and changing the figure size. Keys 'e' and 'c' respectively + expand and contract the figure, and key 'q' closes it. - **Note:** Resizing may not function correctly with all matplotlib backends - (a `bug `__ has been - reported). + **Note:** Resizing may not function correctly with all matplotlib + backends + (a `bug `__ + has been reported). Args: - fig (:class:`matplotlib.figure.Figure` object): Figure to which event - handling is to be attached - scaling (float, optional (default 1.1)): Scaling factor for figure - size changes + fig (:class:`matplotlib.figure.Figure` object): Figure to which + event handling is to be attached. + scaling (float, optional (default 1.1)): Scaling factor for + figure size changes. Returns: - function: Key press event handler function + function: Key press event handler function. """ def press(event): @@ -83,18 +86,19 @@ def press(event): def _attach_zoom(ax, scaling=2.0): - """ - Attach an event handler that supports zooming within a plot using the mouse - scroll wheel. + """Attach a scroll wheel event handler. + + Attach an event handler that supports zooming within a plot using the + mouse scroll wheel. Args: - ax (:class:`matplotlib.axes.Axes` object): Axes to which event handling - is to be attached - scaling (float, optional (default 2.0)): Scaling factor for zooming in - and out + ax (:class:`matplotlib.axes.Axes` object): Axes to which event + handling is to be attached. + scaling (float, optional (default 2.0)): Scaling factor for + zooming in and out. Returns: - function: Mouse scroll wheel event handler function + function: Mouse scroll wheel event handler function. """ # See https://stackoverflow.com/questions/11551049 @@ -176,58 +180,60 @@ def zoom(event): def plot(y, x=None, ptyp="plot", xlbl=None, ylbl=None, title=None, lgnd=None, lglc=None, **kwargs): - """ - Plot points or lines in 2D. If a figure object is specified then the plot - is drawn in that figure, and ``fig.show()`` is not called. The figure is - closed on key entry 'q'. + """Plot points or lines in 2D. + + Plot points or lines in 2D. If a figure object is specified then the + plot is drawn in that figure, and ``fig.show()`` is not called. The + figure is closed on key entry 'q'. Args: - y (array_like): 1d or 2d array of data to plot. If a 2d array, each - column is plotted as a separate curve. - x (array_like, optional (default None)): Values for x-axis of the plot + y (array_like): 1d or 2d array of data to plot. If a 2d array, + each column is plotted as a separate curve. + x (array_like, optional (default None)): Values for x-axis of the + plot. 
ptyp (string, optional (default 'plot')): Plot type specification - (options are 'plot', 'semilogx', 'semilogy', and 'loglog') - xlbl (string, optional (default None)): Label for x-axis - ylbl (string, optional (default None)): Label for y-axis - title (string, optional (default None)): Figure title - lgnd (list of strings, optional (default None)): List of legend string - lglc (string, optional (default None)): Legend location string + (options are 'plot', 'semilogx', 'semilogy', and 'loglog'). + xlbl (string, optional (default None)): Label for x-axis. + ylbl (string, optional (default None)): Label for y-axis. + title (string, optional (default None)): Figure title. + lgnd (list of strings, optional (default None)): List of legend + strings. + lglc (string, optional (default None)): Legend location string. **kwargs: :class:`matplotlib.lines.Line2D` properties or figure - properties + properties. Keyword arguments specifying :class:`matplotlib.lines.Line2D` - properties, e.g. ``lw=2.0`` sets a line width of 2, or properties - of the figure and axes. If not specified, the defaults for line - width (``lw``) and marker size (``ms``) are 1.5 and 6.0 - respectively. The valid figure and axes keyword arguments are - listed below: + properties, e.g. ``lw=2.0`` sets a line width of 2, or + properties of the figure and axes. If not specified, the + defaults for line width (``lw``) and marker size (``ms``) are + 1.5 and 6.0 respectively. The valid figure and axes keyword + arguments are listed below: .. |mplfg| replace:: :class:`matplotlib.figure.Figure` object .. |mplax| replace:: :class:`matplotlib.axes.Axes` object .. rst-class:: kwargs - ===== ==================== ====================================== + ===== ==================== =================================== kwarg Accepts Description - ===== ==================== ====================================== + ===== ==================== =================================== fgsz tuple (width,height) Specify figure dimensions in inches fgnm integer Figure number of figure fig |mplfg| Draw in specified figure instead of creating one ax |mplax| Plot in specified axes instead of current axes of figure - ===== ==================== ====================================== + ===== ==================== =================================== Returns: - tuple: a tuple (fig, ax) containing: - - - **fig** (:class:`matplotlib.figure.Figure` object): Figure object - for this figure - - **ax** (:class:`matplotlib.axes.Axes` object): Axes object for - this plot + - **fig** (:class:`matplotlib.figure.Figure` object): + Figure object for this figure. + - **ax** (:class:`matplotlib.axes.Axes` object): + Axes object for this plot. Raises: - ValueError: If an invalid plot type is specified via parameter `ptyp` + ValueError: If an invalid plot type is specified via parameter + `ptyp`. """ # Extract kwargs entries that are not related to line properties @@ -301,44 +307,46 @@ def surf( fig=None, ax=None, ): - """ - Plot a 2D surface in 3D. If a figure object is specified then the surface - is drawn in that figure, and ``fig.show()`` is not called. The figure is - closed on key entry 'q'. + """Plot a 2D surface in 3D. + + Plot a 2D surface in 3D. If a figure object is specified then the + surface is drawn in that figure, and ``fig.show()`` is not called. + The figure is closed on key entry 'q'.
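A minimal usage sketch ahead of the parameter list (array values are arbitrary; a working matplotlib backend is assumed):

    import numpy as np
    from scico.plot import surf

    x = np.linspace(-1.0, 1.0, 64)
    X, Y = np.meshgrid(x, x)
    fig, ax = surf(np.exp(-(X**2 + Y**2)), x=x, y=x, title="Gaussian bump")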
Args: - z (array_like): 2d array of data to plot - x (array_like, optional (default None)): Values for x-axis of the plot - y (array_like, optional (default None)): Values for y-axis of the plot - elev (float): Elevation angle (in degrees) in the z plane - azim (foat): Azimuth angle (in degrees) in the x,y plane - xlbl (string, optional (default None)): Label for x-axis - ylbl (string, optional (default None)): Label for y-axis - zlbl (string, optional (default None)): Label for z-axis - title (string, optional (default None)): Figure title - lblpad (float, optional (default 8.0)): Label padding - alpha (float between 0.0 and 1.0, optional (default 1.0)): Transparency - cntr (int or sequence of ints, optional (default None)): If not None, - plot contours of the surface on the lower end of the z-axis. An int - specifies the number of contours to plot, and a sequence specifies - the specific contour levels to plot. + z (array_like): 2d array of data to plot. + x (array_like, optional (default None)): Values for x-axis of the + plot. + y (array_like, optional (default None)): Values for y-axis of the + plot. + elev (float): Elevation angle (in degrees) in the z plane. + azim (float): Azimuth angle (in degrees) in the x,y plane. + xlbl (string, optional (default None)): Label for x-axis. + ylbl (string, optional (default None)): Label for y-axis. + zlbl (string, optional (default None)): Label for z-axis. + title (string, optional (default None)): Figure title. + lblpad (float, optional (default 8.0)): Label padding. + alpha (float between 0.0 and 1.0, optional (default 1.0)): + Transparency. + cntr (int or sequence of ints, optional (default None)): If not + None, plot contours of the surface on the lower end of the + z-axis. An int specifies the number of contours to plot, and + a sequence specifies the specific contour levels to plot. cmap (:class:`matplotlib.colors.Colormap` object, optional (default None)): - Colour map for surface. If none specifed, defaults to cm.YlOrRd - fgsz (tuple (width,height), optional (default None)): Specify figure - dimensions in inches - fgnm (integer, optional (default None)): Figure number of figure + Colour map for surface. If none specified, defaults to ``cm.YlOrRd``. + fgsz (tuple (width,height), optional (default None)): Specify + figure dimensions in inches. + fgnm (integer, optional (default None)): Figure number of figure. fig (:class:`matplotlib.figure.Figure` object, optional (default None)): - Draw in specified figure instead of creating one + Draw in specified figure instead of creating one. ax (:class:`matplotlib.axes.Axes` object, optional (default None)): - Plot in specified axes instead of creating one + Plot in specified axes instead of creating one. Returns: - tuple: a tuple (fig, ax) containing: - - - **fig** (:class:`matplotlib.figure.Figure` object): Figure object - for this figure - - **ax** (:class:`matplotlib.axes.Axes` object): Axes object for - this plot + - **fig** (:class:`matplotlib.figure.Figure` object): + Figure object for this figure. + - **ax** (:class:`matplotlib.axes.Axes` object): + Axes object for this plot. """ figp = fig @@ -418,50 +426,54 @@ def contour( fig=None, ax=None, ): - """ - Contour plot of a 2D surface. If a figure object is specified then the - plot is drawn in that figure, and ``fig.show()`` is not called. The - figure is closed on key entry 'q'. + """Contour plot of a 2D surface. + + Contour plot of a 2D surface.
If a figure object is specified then + the plot is drawn in that figure, and ``fig.show()`` is not called. + The figure is closed on key entry 'q'. Args: - z (array_like): 2d array of data to plot - x (array_like, optional (default None)): Values for x-axis of the plot - y (array_like, optional (default None)): Values for y-axis of the plot - v (int or sequence of floats, optional (default 5)): An int specifies - the number of contours to plot, and a sequence specifies the - specific contour levels to plot. - xlog (boolean, optional (default False)): Set x-axis to log scale - ylog (boolean, optional (default False)): Set y-axis to log scale - xlbl (string, optional (default None)): Label for x-axis - ylbl (string, optional (default None)): Label for y-axis - title (string, optional (default None)): Figure title - cfmt (string, optional (default None)): Format string for contour labels. - cfntsz (int or None, optional (default 10)): Contour label font size. - No contour labels are displayed if set to 0 or None. - lfntsz (int, optional (default None)): Axis label font size. The default - font size is used if set to None. - alpha (float, optional (default 1.0)): Underlying image display alpha - value + z (array_like): 2d array of data to plot. + x (array_like, optional (default None)): Values for x-axis of the + plot. + y (array_like, optional (default None)): Values for y-axis of the + plot. + v (int or sequence of floats, optional (default 5)): An int + specifies the number of contours to plot, and a sequence + specifies the specific contour levels to plot. + xlog (boolean, optional (default False)): Set x-axis to log + scale. + ylog (boolean, optional (default False)): Set y-axis to log + scale. + xlbl (string, optional (default None)): Label for x-axis. + ylbl (string, optional (default None)): Label for y-axis. + title (string, optional (default None)): Figure title. + cfmt (string, optional (default None)): Format string for contour + labels. + cfntsz (int or None, optional (default 10)): Contour label font + size. No contour labels are displayed if set to 0 or None. + lfntsz (int, optional (default None)): Axis label font size. The + default font size is used if set to None. + alpha (float, optional (default 1.0)): Underlying image display + alpha value. cmap (:class:`matplotlib.colors.Colormap`, optional (default None)): - Colour map for surface. If none specifed, defaults to cm.YlOrRd - vmin, vmax (float, optional (default None)): Set upper and lower bounds - for the colour map (see the corresponding parameters of - :meth:`matplotlib.axes.Axes.imshow`) - fgsz (tuple (width,height), optional (default None)): Specify figure - dimensions in inches - fgnm (integer, optional (default None)): Figure number of figure + Colour map for surface. If none specified, defaults to ``cm.YlOrRd``. + vmin, vmax (float, optional (default None)): Set upper and lower + bounds for the colour map (see the corresponding parameters + of :meth:`matplotlib.axes.Axes.imshow`). + fgsz (tuple (width,height), optional (default None)): Specify + figure dimensions in inches. + fgnm (integer, optional (default None)): Figure number of figure. fig (:class:`matplotlib.figure.Figure` object, optional (default None)): - Draw in specified figure instead of creating one + Draw in specified figure instead of creating one. ax (:class:`matplotlib.axes.Axes` object, optional (default None)): - Plot in specified axes instead of current axes of figure + Plot in specified axes instead of current axes of figure.
Returns:
- tuple: a tuple (fig, ax) containing:
-
- - **fig** (:class:`matplotlib.figure.Figure` object): Figure object
- for this figure
- - **ax** (:class:`matplotlib.axes.Axes` object): Axes object for
- this plot
+ - **fig** (:class:`matplotlib.figure.Figure` object):
+ Figure object for this figure.
+ - **ax** (:class:`matplotlib.axes.Axes` object):
+ Axes object for this plot.
"""

figp = fig

@@ -556,52 +568,52 @@ def imview(
fig=None,
ax=None,
):
- """
- Display an image. Pixel values are displayed when the pointer is over valid
- image data. If a figure object is specified then the image is drawn in that
- figure, and ``fig.show()`` is not called. The figure is closed on key
- entry 'q'.
+ """Display an image.
+
+ Display an image. Pixel values are displayed when the pointer is over
+ valid image data. If a figure object is specified then the image is
+ drawn in that figure, and ``fig.show()`` is not called. The figure is
+ closed on key entry 'q'.

Args:
- img (array_like, shape (Nr, Nc) or (Nr, Nc, 3) or (Nr, Nc, 4)): Image
- to display
- title (string, optional (default None)): Figure title
- copy (boolean, optional (default True)): If True, create a copy of
- input `img` as a reference for displayed pixel values, ensuring
- that displayed values do not change when the array changes in the
- calling scope. Set this flag to False if the overhead of an
- additional copy of the input image is not acceptable.
- fltscl (boolean, optional (default False)): If True, rescale and shift
- floating point arrays to [0,1]
+ img (array_like, shape (Nr, Nc) or (Nr, Nc, 3) or (Nr, Nc, 4)):
+ Image to display.
+ title (string, optional (default None)): Figure title.
+ copy (boolean, optional (default True)): If True, create a copy
+ of input `img` as a reference for displayed pixel values,
+ ensuring that displayed values do not change when the array
+ changes in the calling scope. Set this flag to False if the
+ overhead of an additional copy of the input image is not
+ acceptable.
+ fltscl (boolean, optional (default False)): If True, rescale and
+ shift floating point arrays to [0,1].
intrp (string, optional (default 'nearest')): Specify type of
interpolation used to display image (see ``interpolation``
- parameter of :meth:`matplotlib.axes.Axes.imshow`)
+ parameter of :meth:`matplotlib.axes.Axes.imshow`).
norm (:class:`matplotlib.colors.Normalize` object, optional (default None)):
- Specify the :class:`matplotlib.colors.Normalize` instance used to
- scale pixel values for input to the colour map
- cbar (boolean, optional (default False)): Flag indicating whether to
- display colorbar
+ Specify the :class:`matplotlib.colors.Normalize` instance
+ used to scale pixel values for input to the colour map.
+ cbar (boolean, optional (default False)): Flag indicating whether
+ to display colorbar.
cmap (:class:`matplotlib.colors.Colormap`, optional (default None)):
- Colour map for image. If none specifed, defaults to cm.Greys_r
- for monochrome image
- fgsz (tuple (width,height), optional (default None)): Specify figure
- dimensions in inches
- fgnm (integer, optional (default None)): Figure number of figure
+ Colour map for image. If none specified, defaults to
+ ``cm.Greys_r`` for monochrome image.
+ fgsz (tuple (width,height), optional (default None)): Specify
+ figure dimensions in inches.
+ fgnm (integer, optional (default None)): Figure number of figure.
fig (:class:`matplotlib.figure.Figure` object, optional (default None)):
- Draw in specified figure instead of creating one
+ Draw in specified figure instead of creating one.
ax (:class:`matplotlib.axes.Axes` object, optional (default None)):
- Plot in specified axes instead of current axes of figure
+ Plot in specified axes instead of current axes of figure.

Returns:
- tuple: a tuple (fig, ax) containing:
-
- - **fig** (:class:`matplotlib.figure.Figure` object): Figure object
- for this figure
- - **ax** (:class:`matplotlib.axes.Axes` object): Axes object for
- this plot
+ - **fig** (:class:`matplotlib.figure.Figure` object):
+ Figure object for this figure.
+ - **ax** (:class:`matplotlib.axes.Axes` object):
+ Axes object for this plot.

Raises:
- ValueError: Description
+ ValueError: If the input array is not of the required shape.
"""

if img.ndim > 2 and img.shape[2] != 3:
@@ -716,13 +728,14 @@ def mouse_move_patch(arg):

def close(fig=None):
- """
- Close figure(s). If a figure object reference or figure number is provided,
- close the specified figure, otherwise close all figures.
+ """Close figure(s).
+
+ Close figure(s). If a figure object reference or figure number is
+ provided, close the specified figure, otherwise close all figures.

Args:
fig (:class:`matplotlib.figure.Figure` object or integer, optional (default None)):
- Figure object or number of figure to close
+ Figure object or number of figure to close.
"""

if fig is None:
@@ -735,7 +748,7 @@ def _in_ipython():
"""Determine whether code is running in an ipython shell.

Returns:
- bool: True if running in an ipython shell, False otherwise
+ bool: True if running in an ipython shell, False otherwise.
"""

try:
@@ -750,7 +763,7 @@ def _in_notebook():
"""Determine whether code is running in a Jupyter Notebook shell.

Returns:
- bool: True if running in a notebook shell, False otherwise
+ bool: True if running in a notebook shell, False otherwise.
"""

try:
@@ -762,15 +775,18 @@ def set_ipython_plot_backend(backend="qt"):
- """Set matplotlib backend within an ipython shell. Ths function has the
- same effect as the line magic ``%matplotlib [backend]`` but is called as a
- function and includes a check to determine whether the code is running in
- an ipython shell, so that it can safely be used within a normal python
- script since it has no effect when not running in an ipython shell.
+ """Set matplotlib backend within an ipython shell.
+
+ Set matplotlib backend within an ipython shell. This function has the
+ same effect as the line magic ``%matplotlib [backend]`` but is called
+ as a function and includes a check to determine whether the code is
+ running in an ipython shell, so that it can safely be used within a
+ normal python script since it has no effect when not running in an
+ ipython shell.

Args:
- backend (string, optional (default 'qt')): Name of backend to be passed
- to the ``%matplotlib`` line magic command
+ backend (string, optional (default 'qt')): Name of backend to be
+ passed to the ``%matplotlib`` line magic command.
"""

if _in_ipython():
@@ -779,16 +795,18 @@ def set_notebook_plot_backend(backend="inline"):
- """Set matplotlib backend within a Jupyter Notebook shell.
Ths function has - the same effect as the line magic ``%matplotlib [backend]`` but is called - as a function and includes a check to determine whether the code is running - in a notebook shell, so that it can safely be used within a normal python - script since it has no effect when not running in a notebook shell. + """Set matplotlib backend within a Jupyter Notebook shell. - Args: - backend (string, optional (default 'inline')): Name of backend to be - passed to the ``%matplotlib`` line magic command + Set matplotlib backend within a Jupyter Notebook shell. This function + has the same effect as the line magic ``%matplotlib [backend]`` but + is called as a function and includes a check to determine whether the + code is running in a notebook shell, so that it can safely be used + within a normal python script since it has no effect when not running + in a notebook shell. + Args: + backend (string, optional (default 'inline')): Name of backend to + be passed to the ``%matplotlib`` line magic command. """ if _in_notebook(): @@ -797,10 +815,12 @@ def set_notebook_plot_backend(backend="inline"): def config_notebook_plotting(): - """Configure plotting functions for inline plotting within a Jupyter - Notebook shell. This function has no effect when not within a notebook - shell, and may therefore be used within a normal python script. + """Configure plotting functions for inline plotting. + Configure plotting functions for inline plotting within a Jupyter + Notebook shell. This function has no effect when not within a + notebook shell, and may therefore be used within a normal python + script. """ # Check whether running within a notebook shell and have diff --git a/scico/random.py b/scico/random.py index a18f07a9a..4ea5f716a 100644 --- a/scico/random.py +++ b/scico/random.py @@ -8,9 +8,9 @@ """Random number generation. This module provides convenient wrappers around several `jax.random -`_ routines to handle -the generation and splitting of PRNG keys, as well as the generation of random -:class:`.BlockArray`. +`_ routines to +handle the generation and splitting of PRNG keys, as well as the +generation of random :class:`.BlockArray`: :: @@ -24,9 +24,9 @@ y, key = scico.random.randn((2,), key=key) print(y) # [ 0.00870693 -0.04888531] -The user is responsible for passing the PRNG key to :mod:`scico.random` functions. -If no key is passed, repeated calls to :mod:`scico.random` functions will return the same -random numbers: +The user is responsible for passing the PRNG key to :mod:`scico.random` +functions. If no key is passed, repeated calls to :mod:`scico.random` +functions will return the same random numbers: :: @@ -38,7 +38,8 @@ print(y) # [ 0.19307713 -0.52678305] -If the desired shape is a tuple containing tuples, a :class:`.BlockArray` is returned: +If the desired shape is a tuple containing tuples, a :class:`.BlockArray` +is returned: :: @@ -66,17 +67,19 @@ def _add_seed(fun): """ - Modifies a jax.random function to add a `seed` argument. + Modify a :mod:`jax.random` function to add a `seed` argument. Args: - fun: function to be modified, e.g., jax.random.normal. Expects `key` - to be the first argument. + fun: function to be modified, e.g., :func:`jax.random.normal`. + Expects `key` to be the first argument. Returns: - fun_alt: a version of `fun` supporting an optional `seed` argument that - is used to create a `jax.random.PRNGKey` that is passed along as the `key`. - The `key` argument may still be used, but is moved to be second-to-last. - By default, `seed=0`. 
The `seed` argument is added last. Other arguments are unchanged.
+ fun_alt: a version of `fun` supporting an optional `seed`
+ argument that is used to create a :func:`jax.random.PRNGKey`
+ that is passed along as the `key`. The `key` argument may
+ still be used, but is moved to be second-to-last. By default,
+ `seed=0`. The `seed` argument is added last. Other arguments
+ are unchanged.
"""

# find number of arguments to fun
@@ -122,7 +125,7 @@ def fun_alt(*args, key=None, seed=None, **kwargs):

def _allow_block_shape(fun):
"""
- Decorates a jax.random function so that the `shape` argument may be a BlockShape.
+ Decorate a jax.random function so that the `shape` argument may be a BlockShape.
"""

# use inspect to find which argument number is `shape`
@@ -186,21 +189,25 @@ def randn(
key: Optional[PRNGKey] = None,
seed: Optional[int] = None,
) -> Tuple[Union[JaxArray, BlockArray], PRNGKey]:
- """Return an array drawn from the standard normal distribution. Alias for :func:`scico.random.normal`.
+ """Return an array drawn from the standard normal distribution.
+
+ Alias for :func:`scico.random.normal`.

Args:
- shape: Shape of output array. If shape is a tuple, a DeviceArray is returned.
- If shape is a tuple of tuples, a :class:`.BlockArray` is returned.
- key: JAX PRNGKey. Defaults to None, in which case a new key
- is created using the seed arg.
- seed: Seed for new PRNGKey. Default: 0
- dtype: dtype for returned value. Default to float32. If np.complex64,
- generates an array sampled from complex normal distribution.
+ shape: Shape of output array. If shape is a tuple, a
+ DeviceArray is returned. If shape is a tuple of tuples, a
+ :class:`.BlockArray` is returned.
+ key: JAX PRNGKey. Defaults to None, in which case a new key
+ is created using the seed arg.
+ seed: Seed for new PRNGKey. Default: 0.
+ dtype: dtype for returned value. Defaults to ``np.float32``.
+ If ``np.complex64``, generates an array sampled from complex
+ normal distribution.

Returns:
tuple: A tuple (x, key) containing:

- - **x** : (DeviceArray): Generated random array
+ - **x** : (DeviceArray): Generated random array.
- **key** : Updated random PRNGKey.
"""
return normal(shape, dtype, key, seed)
diff --git a/scico/scipy/special.py b/scico/scipy/special.py
index 04905cc2f..4cfd578a8 100644
--- a/scico/scipy/special.py
+++ b/scico/scipy/special.py
@@ -7,7 +7,13 @@

"""Wrapped versions of :mod:`jax.scipy.special` functions.

-This modules consists of functions from :mod:`jax.scipy.special`. Some of these functions are wrapped to support compatibility with :class:`scico.blockarray.BlockArray` and are documented here. The remaining functions are imported directly from :mod:`jax.numpy`. While they can be imported from the :mod:`scico.numpy` namespace, they are not documented here; please consult the documentation for the source module :mod:`jax.scipy.special`.
+This module consists of functions from :mod:`jax.scipy.special`. Some
+of these functions are wrapped to support compatibility with
+:class:`scico.blockarray.BlockArray` and are documented here. The
+remaining functions are imported directly from
+:mod:`jax.scipy.special`. While they can be imported from the
+:mod:`scico.scipy.special` namespace, they are not documented here;
+please consult the documentation for the source module
+:mod:`jax.scipy.special`.
""" __author__ = "Luke Pfister " diff --git a/scico/solver.py b/scico/solver.py index 03ffc6fcb..faaf2100a 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -27,14 +27,16 @@ def _wrap_func(func: Callable, shape: Union[Shape, BlockShape], dtype: DType) -> Callable: - """Computes function evaluation (without gradient) for use in :mod:`scipy.optimize` functions. + """Function evaluation for use in :mod:`scipy.optimize`. - Reshapes the input to ``func`` to have ``shape``. Evaluates ``func``. + Compute function evaluation (without gradient) for use in + :mod:`scipy.optimize` functions. Reshapes the input to ``func`` to + have ``shape``. Evaluates ``func``. Args: func: The function to minimize. shape: Shape of input to ``func``. - dtype: Data type of input to ``func`` + dtype: Data type of input to ``func``. """ val_func = jax.jit(func) @@ -53,15 +55,17 @@ def wrapper(x, *args): def _wrap_func_and_grad(func: Callable, shape: Union[Shape, BlockShape], dtype: DType) -> Callable: - """Computes function evaluation and gradient for use in :mod:`scipy.optimize` functions. + """Function evaluation and gradient for use in :mod:`scipy.optimize`. - Reshapes the input to ``func`` to have ``shape``. Evaluates ``func`` and computes gradient. - Ensures the returned ``grad`` is an ndarray. + Compute function evaluation and gradient for use in + :mod:`scipy.optimize` functions. Reshapes the input to ``func`` to + have ``shape``. Evaluates ``func`` and computes gradient. Ensures + the returned ``grad`` is an ndarray. Args: func: The function to minimize. shape: Shape of input to ``func``. - dtype: Data type of input to ``func`` + dtype: Data type of input to ``func``. """ # argnums=0 ensures only differentiate func wrt first argument, @@ -83,17 +87,18 @@ def wrapper(x, *args): def split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: - """Splits an array of shape (N,M,...) into real and imaginary parts. + """Split an array of shape (N,M,...) into real and imaginary parts. Args: - x: Array to split. + x: Array to split. Returns: - A real ndarray with stacked real/imaginary parts. If ``x`` has shape - (M, N, ...), the returned array will have shape (2, M, N, ...) - where the first slice contains the ``x.real`` and the second contains - ``x.imag``. If `x` is a BlockArray, this function is called on each block - and the output is joined into a BlockArray. + A real ndarray with stacked real/imaginary parts. If ``x`` has + shape (M, N, ...), the returned array will have shape + (2, M, N, ...) where the first slice contains the ``x.real`` and + the second contains ``x.imag``. If `x` is a BlockArray, this + function is called on each block and the output is joined into a + BlockArray. """ if isinstance(x, BlockArray): return BlockArray.array([split_real_imag(_) for _ in x]) @@ -102,7 +107,18 @@ def split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArra def join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: - """Join a real array of shape (2,N,M,...) into a complex array of length (N,M, ...)""" + """Join a real array of shape (2,N,M,...) into a complex array. + + Join a real array of shape (2,N,M,...) into a complex array of length + (N,M, ...). + + Args: + x: Array to join. + + Returns: + A complex array with real and imaginary parts taken from ``x[0]`` + and ``x[1]`` respectively. 
+ """ if isinstance(x, BlockArray): return BlockArray.array([join_real_imag(_) for _ in x]) else: @@ -127,18 +143,20 @@ def minimize( callback: Optional[Callable] = None, options: Optional[dict] = None, ) -> spopt.OptimizeResult: - """Minimization of scalar function of one or more variables. Wrapper around - :func:`scipy.optimize.minimize`. + """Minimization of scalar function of one or more variables. - This function differs from :func:`scipy.optimize.minimize` in three ways: + Wrapper around :func:`scipy.optimize.minimize`. This function differs + from :func:`scipy.optimize.minimize` in three ways: - - The ``jac`` options of :func:`scipy.optimize.minimize` are not supported. The gradient is calculated using ``jax.grad``. - - Functions mapping from N-dimensional arrays -> float are supported + - The ``jac`` options of :func:`scipy.optimize.minimize` are not + supported. The gradient is calculated using ``jax.grad``. + - Functions mapping from N-dimensional arrays -> float are + supported. - Functions mapping from complex arrays -> float are supported. - Docstring for :func:`scipy.optimize.minimize` follows. For descriptions of - the optimization methods and custom minimizers, refer to the original - docstring for :func:`scipy.optimize.minimize`. + Docstring for :func:`scipy.optimize.minimize` follows. For + descriptions of the optimization methods and custom minimizers, refer + to the original docstring for :func:`scipy.optimize.minimize`. Args: func: The objective function to be minimized. @@ -335,11 +353,11 @@ def minimize_scalar( options: Optional[dict] = None, ) -> spopt.OptimizeResult: - """Minimization of scalar function of one variable. Wrapper around - :func:`scipy.optimize.minimize_scalar`. + """Minimization of scalar function of one variable. - Docstring for :func:`scipy.optimize.minimize_scalar` follows. - For descriptions of the optimization methods and custom minimizers, refer to the original + Wrapper around :func:`scipy.optimize.minimize_scalar`. Docstring for + :func:`scipy.optimize.minimize_scalar` follows. For descriptions of + the optimization methods and custom minimizers, refer to the original docstring for :func:`scipy.optimize.minimize_scalar`. Args: @@ -414,22 +432,23 @@ def cg( gradient method. Args: - A: Function implementing linear operator :math:`A` - b: Input array :math:`\mb{b}` - x0: Initial solution - tol: Relative residual stopping tolerance. Default: 1e-5 - Convergence occurs when ``norm(residual) <= max(tol * norm(b), atol)``. - atol : Absolute residual stopping tolerance. Default: 0.0 - Convergence occurs when ``norm(residual) <= max(tol * norm(b), atol)`` - maxiter: Maximum iterations. Default: 1000 - M: Preconditioner for A. The preconditioner should approximate the - inverse of ``A``. The default, ``None``, uses no preconditioner. + A: Function implementing linear operator :math:`A`. + b: Input array :math:`\mb{b}`. + x0: Initial solution. + tol: Relative residual stopping tolerance. Convergence occurs + when ``norm(residual) <= max(tol * norm(b), atol)``. + atol : Absolute residual stopping tolerance. Convergence occurs + when ``norm(residual) <= max(tol * norm(b), atol)``. + maxiter: Maximum iterations. Default: 1000. + M: Preconditioner for A. The preconditioner should approximate + the inverse of ``A``. The default, ``None``, does not use a + preconditioner. Returns: tuple: A tuple (x, info) containing: - - **x** : Solution array - - **info**: Dictionary containing diagnostic information + - **x** : Solution array. 
+ - **info**: Dictionary containing diagnostic information. """ if M is None: diff --git a/scico/typing.py b/scico/typing.py index 53a8766c3..3d4eadc8a 100644 --- a/scico/typing.py +++ b/scico/typing.py @@ -18,21 +18,21 @@ JaxArray = Union[jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray] -"""A jax array""" +"""A jax array.""" Array = Union[np.ndarray, JaxArray] -"""Either a numpy or jax array""" +"""Either a numpy or jax array.""" PRNGKey = jnp.ndarray -"""A key for jax random number generators (see :mod:`jax.random`)""" +"""A key for jax random number generators (see :mod:`jax.random`).""" -DType = Any # TODO: can we do better than this? Maybe with the new numpy typing? -"""A numpy or jax dtype""" +DType = Any # TODO: can we do better than this? Maybe with the new numpy typing? +"""A numpy or jax dtype.""" -Shape = Tuple[int, ...] # Shape of an array -"""A shape of a numpy or jax array""" +Shape = Tuple[int, ...] # shape of an array +"""A shape of a numpy or jax array.""" -BlockShape = Tuple[Tuple[int, ...], ...] # Shape of a BlockArray -"""A shape of a :class:`.BlockArray`""" +BlockShape = Tuple[Tuple[int, ...], ...] # shape of a BlockArray +"""A shape of a :class:`.BlockArray`.""" Axes = Union[int, Tuple[int, ...]] # one or more axes diff --git a/scico/util.py b/scico/util.py index 645d735b0..81f4734dc 100644 --- a/scico/util.py +++ b/scico/util.py @@ -114,20 +114,20 @@ def url_get(url: str, maxtry: int = 3, timeout: int = 10) -> io.BytesIO: # prag """Get content of a file via a URL. Args: - url: URL of the file to be downloaded - maxtry: Maximum number of download retries. Default: 3. - timeout: Timeout in seconds for blocking operations. Default: 10 + url: URL of the file to be downloaded. + maxtry: Maximum number of download retries. + timeout: Timeout in seconds for blocking operations. Returns: - Buffered I/O stream + Buffered I/O stream. Raises: - ValueError: If the maxtry parameter is not greater than zero - urllib.error.URLError: If the file cannot be downloaded + ValueError: If the maxtry parameter is not greater than zero. + urllib.error.URLError: If the file cannot be downloaded. """ if maxtry <= 0: - raise ValueError("Parameter maxtry should be greater than zero") + raise ValueError("Parameter maxtry should be greater than zero.") for ntry in range(maxtry): try: rspns = urlrequest.urlopen(url, timeout=timeout) @@ -150,7 +150,7 @@ def parse_axes( Args: axes: user specification of one or more axes: int, list, tuple, - or ``None`` + or ``None``. shape: the shape of the array of which axes are being specified. If not ``None``, `axes` is checked to make sure its entries refer to axes that exist in `shape`. @@ -158,13 +158,13 @@ def parse_axes( default, `list(range(len(shape)))`. Returns: - List of axes (never an int, never ``None``) + List of axes (never an int, never ``None``). """ if axes is None: if default is None: if shape is None: - raise ValueError("`axes` cannot be `None` without a default or shape specified") + raise ValueError("`axes` cannot be `None` without a default or shape specified.") else: axes = list(range(len(shape))) else: @@ -177,10 +177,10 @@ def parse_axes( raise ValueError(f"Could not understand axes {axes} as a list of axes") if shape is not None and max(axes) >= len(shape): raise ValueError( - f"Invalid axes {axes} specified; each axis must be less than `len(shape)`={len(shape)}" + f"Invalid axes {axes} specified; each axis must be less than `len(shape)`={len(shape)}." 
)
elif len(set(axes)) != len(axes):
- raise ValueError("Duplicate vaue in axes {axes}; each axis must be unique")
+ raise ValueError(f"Duplicate value in axes {axes}; each axis must be unique.")
return axes

@@ -216,9 +216,9 @@ def check_for_tracer(func: Callable) -> Callable:
"""Check if positional arguments to ``func`` are jax tracers.

This is intended to be used as a decorator for functions that call
- external code from within SCICO. At present, external functions cannot
- be jit-ed or vmap/pmaped. This decorator checks for signs of jit/vmap/pmap
- and raises an appropriate exception.
+ external code from within SCICO. At present, external functions
+ cannot be jit-ed or vmap/pmap-ed. This decorator checks for signs of
+ jit/vmap/pmap and raises an appropriate exception.
"""

@wraps(func)
@@ -254,9 +254,9 @@ def __init__(
):
"""
Args:
labels: Label(s) of the timer(s) to be initialised to zero.
default_label : Default timer label to be used when methods
- are called without specifying a label
+ are called without specifying a label.
all_label : Label string that will be used to denote all
- timer labels
+ timer labels.
"""

# Initialise current and accumulated time dictionaries
@@ -389,7 +389,7 @@ def elapsed(self, label: Optional[str] = None, total: bool = True) -> float:
corresponding call to :meth:`stop`.

Returns:
- Elapsed time
+ Elapsed time.
"""

# Get current time
@@ -418,7 +418,7 @@ def labels(self) -> List[str]:
"""Get a list of timer labels.

Returns:
- List of timer labels
+ List of timer labels.
"""

return self.t0.keys()
@@ -428,10 +428,10 @@ def __str__(self) -> str:

The representation consists of a table with the following columns:

- * Timer label
- * Accumulated time from past start/stop calls
+ * Timer label.
+ * Accumulated time from past start/stop calls.
* Time since current start call, or 'Stopped' if timer is not
- currently running
+ currently running.
"""

# Get current time
@@ -484,8 +484,8 @@ def __init__(
):
"""
Args:
- timer: Timer object to be used as a context manager. If ``None``, a
- new class:`Timer` object is constructed.
+ timer: Timer object to be used as a context manager. If
+ ``None``, a new :class:`Timer` object is constructed.
label: Label of the timer to be used. If it is ``None``, start
the default timer.
action: Actions to be taken on context entry and exit. If the
@@ -513,8 +513,8 @@ def __enter__(self):
return self

def __exit__(self, type, value, traceback):
- """Stop the timer and return True if no exception was raised within
- the 'with' block, otherwise return False.
+ """Stop the timer and return True if no exception was raised
+ within the 'with' block, otherwise return False.
"""

if self.action == "StartStop":
@@ -537,7 +537,7 @@ def elapsed(self, total: bool = True) -> float:
corresponding call to :meth:`stop`.

Returns:
- Elapsed time
+ Elapsed time.
"""

return self.timer.elapsed(self.label, total=total)
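A brief usage sketch of the :class:`Timer` and :class:`ContextTimer` interfaces touched by the hunks above may be helpful. It assumes only the behaviour described in these docstrings, i.e. that ``start``/``stop``/``elapsed`` act on the default label when called without arguments and that ``ContextTimer`` can wrap an existing timer; it is illustrative rather than part of the patch. ::

    from scico.util import Timer, ContextTimer

    t = Timer()          # timer with a default label
    t.start()            # start the default timer
    # ... computation to be timed ...
    t.stop()             # stop and accumulate elapsed time
    print(t.elapsed())   # accumulated time for the default label

    # ContextTimer starts the timer on entry and stops it on exit
    with ContextTimer(t):
        pass             # ... computation to be timed ...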