Skip to content

Commit

Permalink
Docstring cleanup and improvement of style guide compliance (#119)
Browse files Browse the repository at this point in the history
* Various cleanup, mainly docstrings

* Docstring cleanup

* Typo bug fix

* Use intersphinx refs

* Some improvements

* Remove explicit copy of jax docstrings

* Fix indentation issues

* Clean up docstrings

* Minor edit

* Style guide compliance

* Style guide compliance

* Cleanup

* Fix docstring format problem

* Cleanup and style compliance

* Apply black manually

* Docstring cleanup and style compliance

* Style compliance

* Improve formatting of returns specification

* Alternative formatting of multiple returns

* Formatting improvement

* Minor changes

* Format as code

Co-authored-by: Mike McCann <[email protected]>
  • Loading branch information
bwohlberg and Michael-T-McCann authored Dec 6, 2021
1 parent 3bdec67 commit 9438cf5
Show file tree
Hide file tree
Showing 40 changed files with 1,479 additions and 1,233 deletions.
4 changes: 4 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ def patched_parse(self):
# Document instance attributes with the :ivar: role in class docstrings.
napoleon_use_ivar = True
# Fold the return type into the "Returns" text rather than emitting a
# separate :rtype: field.
napoleon_use_rtype = False

# See https://github.com/sphinx-doc/sphinx/issues/9119
# napoleon_custom_sections = [("Returns", "params_style")]


# Render graphviz output (e.g. inheritance diagrams) as SVG.
graphviz_output_format = "svg"
inheritance_graph_attrs = dict(rankdir="LR", fontsize=9, ratio="compress", bgcolor="transparent")
inheritance_node_attrs = dict(
Expand Down
16 changes: 11 additions & 5 deletions scico/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""SCICO -- a Python package for solving the inverse problems that arise in scientific imaging applications."""
"""Scientific Computational Imaging COde (SCICO) is a Python package for
solving the inverse problems that arise in scientific imaging applications.
"""

__version__ = "0.0.2a1"

Expand All @@ -16,13 +18,17 @@
# TODO remove this check as we get closer to release?
import jax, jaxlib

if jax.__version__ < "0.2.19":
jax_ver_req = "0.2.19"
jaxlib_ver_req = "0.1.70"
if jax.__version__ < jax_ver_req:
raise Exception(
f"""SCICO {__version__} requires jax>0.2.19; got {jax.__version__}; please upgrade jax."""
f"SCICO {__version__} requires jax>={jax_ver_req}; got {jax.__version__}; "
"please upgrade jax."
)
if jaxlib.__version__ < "0.1.70":
if jaxlib.__version__ < jaxlib_ver_req:
raise Exception(
f"""SCICO {__version__} requires jaxlib>0.1.70; got {jaxlib.__version__}; please upgrade jaxlib."""
f"SCICO {__version__} requires jaxlib>={jaxlib_ver_req}; got {jaxlib.__version__}; "
"please upgrade jaxlib."
)

from jax import custom_jvp, custom_vjp, jacfwd, jvp, linearize, vjp, hessian
Expand Down
178 changes: 45 additions & 133 deletions scico/_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Automatric differentiation tools."""
"""Automatic differentiation tools."""


from typing import Any, Callable, Sequence, Tuple, Union
Expand All @@ -13,45 +13,32 @@
import jax.numpy as jnp


def _append_jax_docs(fn, jaxfn=None):
    """Append the jax function docs.

    Given wrapper function ``fn``, concatenate its docstring with the
    docstring of the wrapped jax function.

    Args:
        fn: Wrapper function whose docstring forms the first part of the
            combined text. Must have a non-``None`` ``__doc__``.
        jaxfn: The wrapped jax function. If ``None``, it is looked up on
            the :mod:`jax` module under ``fn.__name__``, so the wrapper
            must share its name with the jax function it wraps.

    Returns:
        The combined docstring text: ``fn``'s docstring followed by a
        ``Docstring for :func:`jax.<name>`:`` header and the jax
        function's docstring with its initial summary lines removed.
    """

    name = fn.__name__
    if jaxfn is None:
        jaxfn = getattr(jax, name)
    # NOTE(review): the whitespace inside these string literals appears
    # mangled by the page scrape (the replace() below is a no-op as
    # written) — upstream this re-indents fn's docstring; confirm against
    # the repository before relying on exact output.
    doc = " " + fn.__doc__.replace("\n ", "\n ")  # deal with indentation differences
    jaxdoc = "\n".join(jaxfn.__doc__.split("\n")[2:])  # strip initial lines
    return doc + f"\n Docstring for :func:`jax.{name}`:\n\n" + jaxdoc


def grad(
fun: Callable,
argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
) -> Callable:
"""Creates a function which evaluates the gradient of ``fun``.
:func:`scico.grad` differs from :func:`jax.grad` in that the output is conjugated.
Docstring for :func:`jax.grad`:
Args:
fun: Function to be differentiated. Its arguments at positions specified by
``argnums`` should be arrays, scalars, or standard Python containers.
Argument arrays in the positions specified by ``argnums`` must be of
inexact (i.e., floating-point or complex) type. It
should return a scalar (which includes arrays with shape ``()`` but not
arrays with shape ``(1,)`` etc.)
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default 0).
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
Returns:
A function with the same arguments as ``fun``, that evaluates the gradient
of ``fun``. If ``argnums`` is an integer then the gradient has the same
shape and type as the positional argument indicated by that integer. If
argnums is a tuple of integers, the gradient is a tuple of values with the
same shapes and types as the corresponding arguments. If ``has_aux`` is True
then a pair of (gradient, auxiliary_data) is returned.
"""Create a function that evaluates the gradient of ``fun``.
:func:`scico.grad` differs from :func:`jax.grad` in that the output
is conjugated.
"""

jax_grad = jax.grad(
Expand All @@ -69,43 +56,21 @@ def conjugated_grad(*args, **kwargs):
return conjugated_grad_aux if has_aux else conjugated_grad


# Append docstring from original jax function
grad.__doc__ = _append_jax_docs(grad)


def value_and_grad(
fun: Callable,
argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
) -> Callable[..., Tuple[Any, Any]]:
"""Create a function which evaluates both ``fun`` and the gradient of ``fun``.
:func:`scico.value_and_grad` differs from :func:`jax.value_and_grad` in that the gradient is conjugated.
Docstring for :func:`jax.value_and_grad`:
Args:
fun: Function to be differentiated. Its arguments at positions specified by
``argnums`` should be arrays, scalars, or standard Python containers. It
should return a scalar (which includes arrays with shape ``()`` but not
arrays with shape ``(1,)`` etc.)
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default 0).
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
Returns:
A function with the same arguments as ``fun`` that evaluates both ``fun``
and the gradient of ``fun`` and returns them as a pair (a two-element
tuple). If ``argnums`` is an integer then the gradient has the same shape
and type as the positional argument indicated by that integer. If argnums is
a sequence of integers, the gradient is a tuple of values with the same
shapes and types as the corresponding arguments.
"""Create a function that evaluates both ``fun`` and its gradient.
:func:`scico.value_and_grad` differs from :func:`jax.value_and_grad`
in that the gradient is conjugated.
"""
jax_val_grad = jax.value_and_grad(
fun=fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int
Expand All @@ -124,50 +89,17 @@ def conjugated_value_and_grad(*args, **kwargs):
return conjugated_value_and_grad_aux if has_aux else conjugated_value_and_grad


# Append docstring from original jax function
value_and_grad.__doc__ = _append_jax_docs(value_and_grad)


def linear_adjoint(fun: Callable, *primals) -> Callable:
"""Conjugate transpose a function that is promised to be linear.
"""Conjugate transpose a function that is guaranteed to be linear.
:func:`scico.linear_adjoint` differs from :func:`jax.linear_transpose`
for complex inputs in that the conjugate transpose (adjoint) of `fun` is returned.
:func:`scico.linear_adjoint` is identical to :func:`jax.linear_transpose`
for real-valued primals.
Docstring for :func:`jax.linear_transpose`:
For linear functions, this transformation is equivalent to ``vjp``, but
avoids the overhead of computing the forward pass.
The outputs of the transposed function will always have the exact same dtypes
as ``primals``, even if some values are truncated (e.g., from complex to
float, or from float64 to float32). To avoid truncation, use dtypes in
``primals`` that match the full range of desired outputs from the transposed
function. Integer dtypes are not supported.
Args:
fun: the linear function to be transposed.
*primals: a positional argument tuple of arrays, scalars, or (nested)
standard Python containers (tuples, lists, dicts, namedtuples, i.e.,
pytrees) of those types used for evaluating the shape/dtype of
``fun(*primals)``. These arguments may be real scalars/ndarrays, but that
is not required: only the ``shape`` and `dtype` attributes are accessed.
See below for an example. (Note that the duck-typed objects cannot be
namedtuples because those are treated as standard Python containers.)
Returns:
A callable that calculates the transpose of ``fun``. Valid input into this
function must have the same shape/dtypes/structure as the result of
``fun(*primals)``. Output will be a tuple, with the same
shape/dtypes/structure as ``primals``.
>>> import jax
>>> import types
>>> import numpy as np
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32)
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
for complex inputs in that the conjugate transpose (adjoint) of `fun`
is returned. :func:`scico.linear_adjoint` is identical to
:func:`jax.linear_transpose` for real-valued primals.
"""

def conj_fun(*primals):
Expand All @@ -189,6 +121,10 @@ def conj_fun(*primals):
return jax.linear_transpose(_fun, *_primals)


# Append docstring from original jax function
linear_adjoint.__doc__ = _append_jax_docs(linear_adjoint, jaxfn=jax.linear_transpose)


def jacrev(
fun: Callable,
argnums: Union[int, Sequence[int]] = 0,
Expand All @@ -197,36 +133,8 @@ def jacrev(
) -> Callable:
"""Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.
:func:`scico.jacrev` differs from :func:`jax.jacrev` in that the output is conjugated.
Docstring for :func:`jax.jacrev`:
Args:
fun: Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default ``0``).
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. Default False.
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
Returns:
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using reverse-mode automatic differentiation.
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
... return jnp.asarray(
... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.]))) # doctest: +SKIP
[[ 1. 0. 0. ]
[ 0. 0. 5. ]
[ 0. 16. -2. ]
[ 1.6209 0. 0.84147]]
:func:`scico.jacrev` differs from :func:`jax.jacrev` in that the
output is conjugated.
"""

jax_jacrev = jax.jacrev(fun=fun, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int)
Expand All @@ -236,3 +144,7 @@ def conjugated_jacrev(*args, **kwargs):
return jax.tree_map(jax.numpy.conj, tmp)

return conjugated_jacrev


# Append docstring from original jax function
jacrev.__doc__ = _append_jax_docs(jacrev)
Loading

0 comments on commit 9438cf5

Please sign in to comment.