diff --git a/data b/data index c86c573b8..cc21bc849 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit c86c573b8a3eee7b0676451c3275d3acc0a7a2d5 +Subproject commit cc21bc849e01986dab84050ab52c86ca1b6922af diff --git a/docs/source/conf.py b/docs/source/conf.py index db4003f00..8b67539a7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -395,7 +395,11 @@ class ExperimentAnalysis: snp_func = getmembers(scico.numpy, isfunction) for _, f in snp_func: - if f.__module__[0:14] == "jax._src.numpy" or f.__module__ == "scico.numpy._create": + if ( + f.__module__ == "scico.numpy" + or f.__module__[0:14] == "jax._src.numpy" + or f.__module__ == "scico.numpy._create" + ): # Rewrite module name so that function is included in docs f.__module__ = "scico.numpy" # Attempt to fix incorrect cross-reference @@ -403,6 +407,12 @@ class ExperimentAnalysis: modname = "numpy.char" else: modname = "numpy" + f.__doc__ = re.sub( + r"^:func:`([\w_]+)` wrapped to operate", + r":obj:`jax.numpy.\1` wrapped to operate", + str(f.__doc__), + flags=re.M, + ) f.__doc__ = re.sub( r"^LAX-backend implementation of :func:`([\w_]+)`.", r"LAX-backend implementation of :obj:`%s.\1`." % modname, @@ -429,6 +439,40 @@ class ExperimentAnalysis: scico.numpy.vectorize.__doc__ = re.sub("^ ", "", scico.numpy.vectorize.__doc__, flags=re.M) +# Similar processing for scico.scipy +import scico.scipy + +ssp_func = getmembers(scico.scipy.special, isfunction) +for _, f in ssp_func: + if f.__module__[0:11] == "scico.scipy" or f.__module__[0:14] == "jax._src.scipy": + # Attempt to fix incorrect cross-reference + f.__doc__ = re.sub( + r"^:func:`([\w_]+)` wrapped to operate", + r":obj:`jax.scipy.special.\1` wrapped to operate", + str(f.__doc__), + flags=re.M, + ) + modname = "scipy.special" + f.__doc__ = re.sub( + r"^LAX-backend implementation of :func:`([\w_]+)`.", + r"LAX-backend implementation of :obj:`%s.\1`." % modname, + str(f.__doc__), + flags=re.M, + ) + # Remove cross-reference to numpydoc style references section + f.__doc__ = re.sub(r" \[(\d+)\]_", "", f.__doc__, flags=re.M) + # Remove entire numpydoc references section + f.__doc__ = re.sub(r"References\n----------\n.*\n", "", f.__doc__, flags=re.DOTALL) + # Remove problematic citation + f.__doc__ = re.sub("See \[dlmf\]_ for details.", "", f.__doc__, re.M) + f.__doc__ = re.sub("\[dlmf\]_", "NIST DLMF", f.__doc__, re.M) + +# Fix indentation problems +scico.scipy.special.sph_harm.__doc__ = re.sub( + "^Computes the", " Computes the", scico.scipy.special.sph_harm.__doc__, flags=re.M +) + + def class_inherit_diagrams(_): # Insert inheritance diagrams for classes that have base classes import scico diff --git a/docs/source/style.rst b/docs/source/style.rst index 9011c2935..a1d348532 100644 --- a/docs/source/style.rst +++ b/docs/source/style.rst @@ -118,7 +118,7 @@ We follow the `Google string conventions `_. The usage of ``import`` statements should be reserved for packages and modules only excluding individual classes and functions. The only exception to this is the typing module. +We follow the `Google import conventions `_. The use of ``import`` statements should be reserved for packages and modules only, i.e. individual classes and functions should not be imported. The only exception to this is the typing module. - Use ``import x`` for importing packages and modules, where x is the package or module name. - Use ``from x import y`` where x is the package name and y is the module name. @@ -135,7 +135,7 @@ We follow the `Google variable typing conventions `_.: +The following components require the recommended markup taken from the `NumPy Conventions `__.: - Paragraphs: Indentation is significant and indicates the indentation of the output. New paragraphs are marked with a blank line. -- Variable, module, function, and class names: - Should be written in between single back-ticks (`x`). +- Variable, parameter, module, function, method, and class names: + Should be written between single back-ticks (e.g. \`x\`, rendered as `x`), but note that use of `Sphinx cross-reference syntax `_ is preferred for modules (`:mod:\`module-name\`` ), functions (`:func:\`function-name\`` ), methods (`:meth:\`method-name\`` ) and classes (`:class:\`class-name\`` ). - None, NoneType, True, and False: - Should be written in between double back-ticks (``None``, ``True``). + Should be written between double back-ticks (e.g. \`\`None\`\`, \`\`True\`\`, rendered as ``None``, ``True``). - Types: - Should be written in between double back-ticks (``int``). + Should be written between double back-ticks (e.g. \`\`int\`\`, rendered as ``int``). -Other components can use *italics*, **bold**, and ``monospace`` if needed, but not for variable names, doctest code, or multi-line code. +Other components can use \*italics\*, \*\*bold\*\*, and \`\`monospace\`\` (respectively rendered as *italics*, **bold**, and ``monospace``) if needed, but not for variable names, doctest code, or multi-line code. Documentation @@ -463,6 +463,8 @@ A few notable guidelines: * Avoid capitalization in text except where absolutely necessary, e.g., "Newton’s first law." + * Use a single space after the period at the end of a sentence. + The source code (`.rst` files) for these pages does not have a line-length guideline, but line breaks at or before 79 characters are encouraged. diff --git a/scico/array.py b/scico/array.py index 4fae1018a..195d763f9 100644 --- a/scico/array.py +++ b/scico/array.py @@ -21,7 +21,7 @@ import scico.blockarray import scico.numpy as snp -from scico.typing import ArrayIndex, Axes, AxisIndex, JaxArray, Shape +from scico.typing import ArrayIndex, Axes, AxisIndex, DType, JaxArray, Shape def ensure_on_device( @@ -74,8 +74,8 @@ def ensure_on_device( def no_nan_divide( - x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray] -) -> Union[BlockArray, JaxArray]: + x: Union[scico.blockarray.BlockArray, JaxArray], y: Union[scico.blockarray.BlockArray, JaxArray] +) -> Union[scico.blockarray.BlockArray, JaxArray]: """Return `x/y`, with 0 instead of NaN where `y` is 0. Args: diff --git a/scico/blockarray.py b/scico/blockarray.py index 1e9b086e3..90594146e 100644 --- a/scico/blockarray.py +++ b/scico/blockarray.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed diff --git a/scico/data/__init__.py b/scico/data/__init__.py index 9c769787e..4d86820ff 100644 --- a/scico/data/__init__.py +++ b/scico/data/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/examples.py b/scico/examples.py index dd7d585c4..eb40db9c1 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/flax.py b/scico/flax.py index f2d96ad16..02414fb6d 100644 --- a/scico/flax.py +++ b/scico/flax.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/functional/__init__.py b/scico/functional/__init__.py index d7171d5c2..01beef95d 100644 --- a/scico/functional/__init__.py +++ b/scico/functional/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/functional/_flax.py b/scico/functional/_flax.py index 2a8adbd63..82e36e825 100644 --- a/scico/functional/_flax.py +++ b/scico/functional/_flax.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -17,8 +17,6 @@ PyTree = Any -__author__ = """Cristina Garcia-Cardona """ - class FlaxMap(Functional): r"""Functional whose prox applies a trained flax model.""" diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index cfddcb49b..76d52829f 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -16,10 +16,6 @@ from scico.blockarray import BlockArray from scico.typing import JaxArray -__author__ = """\n""".join( - ["Luke Pfister ", "Thilo Balke "] -) - class Functional: r"""Base class for functionals. diff --git a/scico/functional/_indicator.py b/scico/functional/_indicator.py index def48cca5..1c4d54773 100644 --- a/scico/functional/_indicator.py +++ b/scico/functional/_indicator.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -18,8 +18,6 @@ from ._functional import Functional -__author__ = """Luke Pfister """ - class NonNegativeIndicator(Functional): r"""Indicator function for non-negative orthant. @@ -50,17 +48,17 @@ def prox( self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs ) -> Union[JaxArray, BlockArray]: r"""Evaluate the scaled proximal operator of the indicator over - the non-negative orthant, :math:`I_{>= 0} `,: + the non-negative orthant, :math:`I_{>= 0}`, .. math:: [\mathrm{prox}_{\lambda I_{>=0}}(\mb{v})]_i = \begin{cases} - v_i, & \text{if } v_i \geq 0 \\ - 0, & \text{else}. + v_i\,, & \text{if } v_i \geq 0 \\ + 0\,, & \text{otherwise} \;. \end{cases} Args: - v : Input array :math:`\mb{v}`. + v : Input array :math:`\mb{v}`. lam : Proximal parameter :math:`\lambda` (has no effect). kwargs: Additional arguments that may be used by derived classes. diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 29c81c630..92a570a86 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -166,7 +166,7 @@ def prox( .. math:: \mathrm{prox}_{\lambda \| \cdot \|_2}(\mb{v}) - = \mb{v} \left(1 - \frac{\lambda}{\norm{v}_2} \right)_+ \;, + = \mb{v} \left(1 - \frac{\lambda}{\norm{v}_2} \right)_+ \;, where diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index e68c60bb5..a33f40d37 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index 542bcc06c..ee8e0df77 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -23,8 +23,6 @@ from ._linop import LinearOperator from ._stack import LinearOperatorStack -__author__ = """Luke Pfister , Michael McCann """ - class FiniteDifference(LinearOperatorStack): """Finite Difference operator. diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 195869b0c..e829ca9e0 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/linop/radon_astra.py b/scico/linop/radon_astra.py index c22462890..95f8c4142 100644 --- a/scico/linop/radon_astra.py +++ b/scico/linop/radon_astra.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index b261e287b..bd0816670 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 4b5c86dd1..c296b23f5 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -13,7 +13,6 @@ module is a work in progress and therefore not all functions are wrapped. Functions that have not been wrapped yet have WARNING text in their documentation, below. - """ import sys diff --git a/scico/numpy/_create.py b/scico/numpy/_create.py index 3754ba477..8478d71d6 100644 --- a/scico/numpy/_create.py +++ b/scico/numpy/_create.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index 39bcf8d21..af38cdb1e 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index f2bbde399..3c172e3e3 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the diff --git a/scico/scipy/special.py b/scico/scipy/special.py index b4913ab56..b65004cd7 100644 --- a/scico/scipy/special.py +++ b/scico/scipy/special.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -16,7 +16,6 @@ :mod:`jax.scipy.special`. """ -__author__ = "Luke Pfister " import sys diff --git a/scico/solver.py b/scico/solver.py index a16c67e77..255af9b11 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -1,11 +1,60 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. -"""Optimization algorithms.""" +"""SciPy optimization algorithms. + +.. raw:: html + + + +This module provides scico interface wrappers for functions +from :mod:`scipy.optimize` since jax directly implements only a very +limited subset of these functions (there is limited, experimental support +for `L-BFGS-B `_), but only CG +and BFGS are fully supported. These wrappers are required because the +functions in :mod:`scipy.optimize` only support on 1D, real valued, numpy +arrays. These limitations are addressed by: + +- Enabling the use of multi-dimensional arrays by flattening and reshaping + within the wrapper. +- Enabling the use of jax arrays by automatically converting to and from + numpy arrays. +- Enabling the use of complex arrays by splitting them into real and + imaginary parts. + +The wrapper also JIT compiles the function and gradient evaluations. + +The functions provided in this module have a number of advantages and +disadvantages with respect to those in :mod:`jax.scipy.optimize`: + +- This module provides many more algorithms than + :mod:`jax.scipy.optimize`. +- The functions in this module tend to be faster for small-scale problems + (presumably due to some overhead in the jax functions). +- The functions in this module are slower for large problems due to the + frequent host-device copies corresponding to conversion between numpy + arrays and jax arrays. +- The solvers in this module can't be JIT compiled, and gradients cannot + be taken through them. + +In the future, this module may be replaced with a dependency on +`JAXopt `__. +""" from functools import wraps @@ -18,20 +67,18 @@ from scico.typing import BlockShape, DType, JaxArray, Shape from scipy import optimize as spopt -__author__ = """Luke Pfister """ - def _wrap_func(func: Callable, shape: Union[Shape, BlockShape], dtype: DType) -> Callable: """Function evaluation for use in :mod:`scipy.optimize`. Compute function evaluation (without gradient) for use in - :mod:`scipy.optimize` functions. Reshapes the input to ``func`` to - have ``shape``. Evaluates ``func``. + :mod:`scipy.optimize` functions. Reshapes the input to `func` to + have `shape`. Evaluates `func`. Args: func: The function to minimize. - shape: Shape of input to ``func``. - dtype: Data type of input to ``func``. + shape: Shape of input to `func`. + dtype: Data type of input to `func`. """ val_func = jax.jit(func) @@ -53,14 +100,14 @@ def _wrap_func_and_grad(func: Callable, shape: Union[Shape, BlockShape], dtype: """Function evaluation and gradient for use in :mod:`scipy.optimize`. Compute function evaluation and gradient for use in - :mod:`scipy.optimize` functions. Reshapes the input to ``func`` to - have ``shape``. Evaluates ``func`` and computes gradient. Ensures - the returned ``grad`` is an ndarray. + :mod:`scipy.optimize` functions. Reshapes the input to `func` to + have `shape`. Evaluates `func` and computes gradient. Ensures + the returned `grad` is an ndarray. Args: func: The function to minimize. - shape: Shape of input to ``func``. - dtype: Data type of input to ``func``. + shape: Shape of input to `func`. + dtype: Data type of input to `func`. """ # argnums=0 ensures only differentiate func wrt first argument, @@ -82,13 +129,13 @@ def wrapper(x, *args): def _split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: - """Split an array of shape (N,M,...) into real and imaginary parts. + """Split an array of shape (N, M, ...) into real and imaginary parts. Args: x: Array to split. Returns: - A real ndarray with stacked real/imaginary parts. If ``x`` has + A real ndarray with stacked real/imaginary parts. If `x` has shape (M, N, ...), the returned array will have shape (2, M, N, ...) where the first slice contains the ``x.real`` and the second contains ``x.imag``. If `x` is a BlockArray, this @@ -136,8 +183,8 @@ def minimize( Wrapper around :func:`scipy.optimize.minimize`. This function differs from :func:`scipy.optimize.minimize` in three ways: - - The ``jac`` options of :func:`scipy.optimize.minimize` are not - supported. The gradient is calculated using ``jax.grad``. + - The `jac` options of :func:`scipy.optimize.minimize` are not + supported. The gradient is calculated using `jax.grad`. - Functions mapping from N-dimensional arrays -> float are supported. - Functions mapping from complex arrays -> float are supported. @@ -189,11 +236,10 @@ def minimize( options=options, ) - # TODO: need tests for multi-gpu machines # un-vectorize the output array, put on device res.x = snp.reshape( res.x, x0_shape - ) # if x0 was originally a BlockArray be converted back to one here + ) # if x0 was originally a BlockArray then res.x is converted back to one here res.x = res.x.astype(x0_dtype) @@ -226,7 +272,7 @@ def minimize_scalar( """ def f(x, *args): - # Wrap jax-based function ``func`` to return a numpy float + # Wrap jax-based function `func` to return a numpy float # rather than a DeviceArray of size (1,) return func(x, *args).item() @@ -267,8 +313,8 @@ def cg( atol: Absolute residual stopping tolerance. Convergence occurs when ``norm(residual) <= max(tol * norm(b), atol)``. maxiter: Maximum iterations. Default: 1000. - M: Preconditioner for A. The preconditioner should approximate - the inverse of ``A``. The default, ``None``, uses no + M: Preconditioner for `A`. The preconditioner should approximate + the inverse of `A`. The default, ``None``, uses no preconditioner. Returns: