Mike's docs readthrough #53

Merged
merged 7 commits into from
Oct 21, 2021
6 changes: 4 additions & 2 deletions INSTALL.rst
@@ -32,8 +32,10 @@ Automatic Installation

The recommended way to install SCICO and its dependencies is via `conda <https://docs.conda.io/en/latest/>`_ using the scripts in ``misc/conda``:

- ``install_conda.sh``: install ``miniconda`` (needed if conda is not already installed on your system)
- ``conda_env.sh``: install a ``conda`` environment with all SCICO dependencies
- ``install_conda.sh``: install ``miniconda``
(needed if conda is not already installed on your system).
- ``conda_env.sh``: install a ``conda`` environment
with all SCICO dependencies. For GPU installation, see ``conda_env.sh -h``.


Manual Installation
16 changes: 8 additions & 8 deletions docs/source/contributing.rst
@@ -41,7 +41,7 @@ Installing a Development Version
`the jax example <https://jax.readthedocs.io/en/latest/contributing.html#contributing-code-using-pull-requests>`_)


1. Create a conda environment using Python >= 3.8.
1. Create a conda environment using Python >= 3.8:

::
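
    # a sketch; the original command is collapsed in this diff view,
    # and the environment name and Python version are illustrative
    conda create -n scico python=3.9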

@@ -76,7 +76,7 @@ Installing a Development Version
pip install -r examples/examples_requirements.txt # Installs example requirements
pip install -e . # Installs SCICO from the current directory in editable mode.

6. Set up ``black`` and ``isort`` pre-commit hooks
6. Set up ``black`` and ``isort`` pre-commit hooks:

::
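
    # a sketch, assuming the pre-commit framework (https://pre-commit.com)
    # is configured for the repository
    pip install pre-commit
    pre-commit install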

@@ -150,7 +150,7 @@ NOTE: If you have added or modified an example script, see `Adding Usage Exampl

git push --set-upstream origin name-of-change

9. Create a new pull request to the ``main`` branch; see `the GitHub instructions <https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request>`_
9. Create a new pull request to the ``main`` branch; see `the GitHub instructions <https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request>`_.

10. Delete the branch after it has been merged.

@@ -163,7 +163,7 @@ existing examples to ensure that the mechanism for automatically
generating corresponding Jupyter notebooks functions correctly. In
particular:

1. The initial lines of the script should consist of a comment block, followed by a blank line, followed by a multiline string with an RST heading on the first line, e.g.
1. The initial lines of the script should consist of a comment block, followed by a blank line, followed by a multiline string with an RST heading on the first line, e.g.,

::
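
    # A sketch of the expected structure (the original example is
    # collapsed in this diff view); the title text is illustrative.

    r"""
    Example Title
    =============

    A description of what the example demonstrates.
    """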

@@ -218,7 +218,7 @@ and ``scico-data`` repositories must be updated and kept in sync.

3. Convert your new example to a Jupyter notebook by changing directory to the ``scico/examples`` directory and following the instructions in ``scico/examples/README.rst``.

4. Change directory to the ``data`` directory and add/commit the new Jupyter Notebook
4. Change directory to the ``data`` directory and add/commit the new Jupyter Notebook:

::
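
    # a sketch; the notebook path and name are illustrative
    git add notebooks/new_example.ipynb
    git commit -m "Add new usage example notebook"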

@@ -254,7 +254,7 @@ scico-data repositories must be updated and kept in sync.

1. Add the ``new_data.npz`` file to the ``scico/data`` directory.

2. Navigate to the ``data`` directory and add/commit the new data file
2. Navigate to the ``data`` directory and add/commit the new data file:

::
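
    # a sketch; the exact path within the data repository may differ
    git add new_data.npz
    git commit -m "Add new data file"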

@@ -288,7 +288,7 @@ Running Tests
-------------


To be able to run the tests, install `pytest` and, optionally, `pytest-runner`
To be able to run the tests, install `pytest` and, optionally, `pytest-runner`:

::
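
    pip install pytest pytest-runner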

@@ -342,7 +342,7 @@ Test Coverage

Test coverage is a measure of the fraction of the package code that is exercised by the tests. While this should not be the primary criterion in designing tests, it is a useful tool for finding obvious areas of omission.

To be able to check test coverage, install `coverage`
To be able to check test coverage, install `coverage`:

::
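
    pip install coverage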

83 changes: 58 additions & 25 deletions docs/source/functional.rst
@@ -1,5 +1,5 @@
Functionals and Losses
======================
Functionals
===========

.. raw:: html

@@ -22,25 +22,31 @@ Functionals and Losses
}
</style>


A functional maps an :code:`array-like` to a scalar; abstractly, a functional is
A functional is
a mapping from :math:`\mathbb{R}^n` or :math:`\mathbb{C}^n` to :math:`\mathbb{R}`.

A functional ``f`` can have three core operations.
In SCICO, functionals are
primarily used to represent a cost to be minimized
and are represented by instances of the :class:`.Functional` class.
An instance of :class:`.Functional`, ``f``, may provide three core operations.

* Evaluation
- ``f(x)`` returns the value of the functional evaluated at :math:`\mb{x}`.
- A functional that can be evaluated has the attribute ``f.has_eval == True``.
- ``f(x)`` returns the value of the functional
evaluated at the point ``x``.
- A functional that can be evaluated
has the attribute ``f.has_eval == True``.
- Not all functionals can be evaluated: see `Plug-and-Play`_.

* Gradient
- ``f.grad(x)`` returns the gradient of the functional evaluated at :math:`\mb{x}`.
- Calculated using JAX reverse-mode automatic differentiation, exposed through :func:`scico.grad`.
- ``f.grad(x)`` returns the gradient of the functional evaluated at ``x``.
- Gradients are calculated using JAX reverse-mode automatic differentiation,
exposed through :func:`scico.grad`.
- A functional that is smooth has the attribute ``f.is_smooth == True``.
- NOTE: The gradient of a functional ``f`` can be evaluated even if ``f.is_smooth == False``. All that is required is that the functional can be evaluated, ``f.has_eval == True``. However, the result may not be a valid gradient (or subgradient) for all inputs :math:`\mb{x}`.

- NOTE: The gradient of a functional ``f`` can be evaluated even if ``f.is_smooth == False``.
All that is required is that the functional can be evaluated, ``f.has_eval == True``.
However, the result may not be a valid gradient (or subgradient) for all inputs.
* Proximal operator
- The proximal operator of a functional :math:`f : \mathbb{R}^n \to \mathbb{R}` is the mapping
- ``f.prox(v, lam)`` returns the result of the scaled proximal operator
at ``v`` with scale ``lam``.
- The proximal operator of a functional :math:`f : \mathbb{R}^n \to \mathbb{R}` is the mapping
:math:`\mathrm{prox}_f : \mathbb{R}^n \times \mathbb{R} \to \mathbb{R}^n` defined as

.. math::
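   \mathrm{prox}_f(\mb{v}, \lambda) = \mathop{\mathrm{arg\,min}}_{\mb{x}} \; \frac{1}{2} \norm{\mb{v} - \mb{x}}_2^2 + \lambda f(\mb{x})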
@@ -51,14 +51,17 @@ A functional ``f`` can have three core operations.
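For example, a minimal sketch of these three operations, using the
:class:`.SquaredL2Norm` functional (assumed here to provide all three):

::

    import scico.numpy as snp
    from scico import functional

    f = functional.SquaredL2Norm()   # smooth, evaluable, with a known prox
    x = snp.array([3.0, -4.0])

    fx = f(x)            # evaluation of the functional at x
    gx = f.grad(x)       # gradient at x, computed via JAX autodiff
    px = f.prox(x, 1.0)  # scaled proximal operator at x with scale 1.0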
Plug-and-Play
-------------

* For the Plug-and-Play framework :cite:`sreehari-2016-plug`, we encapsulate denoisers/CNNs in a Functional object that **cannot be evaluated**.
* Only the proximal operator is exposed.
For the plug-and-play framework :cite:`sreehari-2016-plug`,
we encapsulate generic denoisers including CNNs
in :class:`.Functional` objects that **cannot be evaluated**.
The denoiser is applied via the proximal operator.
For examples, see :ref:`example_notebooks`.
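A minimal sketch of such a prox-only functional, assuming a generic
``my_denoiser`` callable (a placeholder, not part of SCICO):

::

    from scico.functional import Functional

    class DenoiserFunctional(Functional):
        """Functional wrapping a denoiser; exposes only a proximal operator."""

        has_eval = False   # the underlying cost cannot be evaluated
        is_smooth = False
        has_prox = True

        def prox(self, x, lam):
            # apply the denoiser; lam may control the denoising strength
            return my_denoiser(x, lam)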


Proximal Calculus
-----------------

We support a limited subset of proximal calculus rules.
We support a limited subset of proximal calculus rules:


Scaled Functionals
@@ -75,9 +84,15 @@ determine the proximal method of ``c * f`` as
&= \mathrm{prox}_{f} (v, c \lambda)
\end{align}

Note that we have made no assumptions regarding homogeneity of ``f``; rather, only that the proximal method of ``f`` is given in the parameterized form :math:`\mathrm{prox}_{c f}`.
Note that we have made no assumptions regarding homogeneity of ``f``;
rather, only that the proximal method of ``f`` is given
in the parameterized form :math:`\mathrm{prox}_{c f}`.

In SCICO, multiplying a :class:`.Functional` by a scalar will return a :class:`.ScaledFunctional`. This :class:`.ScaledFunctional` retains the ``has_eval, is_smooth``, and ``has_prox`` attributes from the original :class:`.Functional`, but the proximal method is modified to accomodate the additional scalar.
In SCICO, multiplying a :class:`.Functional` by a scalar
will return a :class:`.ScaledFunctional`.
This :class:`.ScaledFunctional` retains the ``has_eval``, ``is_smooth``, and ``has_prox`` attributes
from the original :class:`.Functional`,
but the proximal method is modified to accommodate the additional scalar.
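For example, a sketch of this behavior using the :class:`.L1Norm` functional:

::

    import scico.numpy as snp
    from scico import functional

    f = functional.L1Norm()
    g = 2.0 * f                      # a ScaledFunctional

    v = snp.array([3.0, -1.0, 0.5])
    # the prox of the scaled functional equals the prox of f with a rescaled step
    snp.allclose(g.prox(v, 1.0), f.prox(v, 2.0))  # True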


Separable Functionals
@@ -89,7 +104,9 @@ of functionals :math:`f_i : \mathbb{C}^{N_i} \to \mathbb{R}` with :math:`\sum_i
.. math::
f(\mb{x}) = f(\mb{x}_1, \dots, \mb{x}_N) = f_1(\mb{x}_1) + \dots + f_N(\mb{x}_N)

The proximal operator of a separable :math:`f` can be written in terms of the proximal operators of the :math:`f_i` (see Theorem 6.6 of :cite:`beck-2017-first`):
The proximal operator of a separable :math:`f` can be written
in terms of the proximal operators of the :math:`f_i`
(see Theorem 6.6 of :cite:`beck-2017-first`):

.. math::
\mathrm{prox}_f(\mb{x}, \lambda)
@@ -106,10 +123,14 @@ Separable Functionals are implemented in the :class:`.SeparableFunctional` class
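For example, a sketch (the block array construction via ``snp.blockarray``
is an assumption here):

::

    import scico.numpy as snp
    from scico import functional

    f1 = functional.L1Norm()
    f2 = functional.SquaredL2Norm()
    f = functional.SeparableFunctional([f1, f2])

    x = snp.blockarray([snp.array([1.0, -2.0]), snp.array([3.0, 4.0])])
    fx = f(x)            # f1(x[0]) + f2(x[1])
    px = f.prox(x, 1.0)  # applies each f_i's prox blockwise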

Adding New Functionals
----------------------
To add a new functional,
create a class which

1. inherits from base :class:`.Functional`;
2. has ``has_eval``, ``is_smooth``, and ``has_prox`` flags;
3. has ``_eval`` and ``prox`` methods, as necessary.

1. Inherit from base functional
2. Set ``has_eval``, ``is_smooth``, and ``has_prox`` flags.
3. Add ``_eval`` and ``prox`` methods, as necessary.
For example,

::
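
    # A sketch; the original example is collapsed in this diff view.
    # Assumes: from scico.functional import Functional; import scico.numpy as snp
    class MyFunctional(Functional):

        has_eval = True
        is_smooth = False
        has_prox = True

        def _eval(self, x):
            return snp.abs(x).sum()

        def prox(self, x, lam):
            # soft thresholding, the proximal operator of the l1 norm
            return snp.sign(x) * snp.maximum(snp.abs(x) - lam, 0.0)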

@@ -129,6 +150,18 @@ Adding New Functionals
Losses
------

.. todo::

   Content missing here

In SCICO, a loss is a special type of functional

.. math::

   f(\mb{x}) = a l( \mb{y}, A(\mb{x}) )

where :math:`a` is a scale parameter,
:math:`l` is a functional,
:math:`\mb{y}` is a set of measurements,
and :math:`A` is an operator.
SCICO uses the class :class:`.Loss` to represent losses.
Loss functionals commonly arise in the context of solving
inverse problems in scientific imaging,
where they are used to represent the mismatch
between predicted measurements :math:`A(\mb{x})`
and actual ones :math:`\mb{y}`.
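For example, a sketch of constructing a squared-error loss (the
:class:`.MatrixOperator` wrapper shown is an assumption here):

::

    import scico.numpy as snp
    from scico import linop, loss

    y = snp.array([1.0, 2.0])                          # measurements
    A = linop.MatrixOperator(snp.array([[1.0, 0.0],
                                        [0.0, 2.0]]))  # forward operator

    f = loss.SquaredL2Loss(y=y, A=A)  # f(x) proportional to ||y - A x||_2^2
    fx = f(snp.array([1.0, 1.0]))     # data fidelity at a candidate x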
22 changes: 11 additions & 11 deletions docs/source/notes.rst
@@ -81,7 +81,7 @@ Double Precision
By default, JAX enforces single-precision numbers. Double precision can be enabled in one of two ways:

1. Setting the environment variable ``JAX_ENABLE_X64=TRUE`` before launching python.
2. Manually set the ``jax_enable_x64`` flag **at program startup**; that is, **before** importing SCICO.
2. Manually setting the ``jax_enable_x64`` flag **at program startup**; that is, **before** importing SCICO.

::

@@ -90,7 +90,7 @@ By default, JAX enforces single-precision numbers. Double precision can be enabl
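    # standard JAX idiom (assumed; the preceding lines are collapsed in this diff)
    from jax.config import config
    config.update("jax_enable_x64", True)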
import scico # continue as usual


For more information, `see the JAX notes on double precision <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision>`_
For more information, see the `JAX notes on double precision <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision>`_.


Random Number Generation
@@ -146,11 +146,11 @@ The function :func:`scico.grad` returns the expected gradient, that is, the conj
JAX gradient. For further discussion, see this
`JAX issue <https://github.com/google/jax/issues/4891>`_.

As a concrete example, consider the function :math:`f(x) = \frac{1}{2}\norm{A
x}_2^2` where :math:`A` is a complex matrix. The gradient of :math:`f` is
usually given :math:`(\nabla f)(x) = A^H A x`, where :math:`A^H` is the
conjugate transpose of :math:`A`. Applying ``jax.grad`` to :math:`f` will yield
:math:`(A^H A x)^*`, where :math:`*` denotes complex conjugation.
As a concrete example, consider the function :math:`f(x) = \frac{1}{2}\norm{\mb{A}
\mb{x}}_2^2` where :math:`\mb{A}` is a complex matrix. The gradient of :math:`f` is
usually given by :math:`(\nabla f)(\mb{x}) = \mb{A}^H \mb{A} \mb{x}`, where :math:`\mb{A}^H` is the
conjugate transpose of :math:`\mb{A}`. Applying ``jax.grad`` to :math:`f` will yield
:math:`(\mb{A}^H \mb{A} \mb{x})^*`, where :math:`*` denotes complex conjugation.

The following code demonstrates the use of ``jax.grad`` and :func:`scico.grad`:
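
::

    import jax
    import scico
    import scico.numpy as snp

    # a sketch; the original demonstration is collapsed in this diff view
    A = snp.array([[1.0 + 1.0j, 2.0 + 0.0j],
                   [3.0 + 0.0j, 4.0 - 2.0j]])
    f = lambda x: 0.5 * snp.linalg.norm(A @ x) ** 2
    x = snp.array([1.0 + 0.0j, 2.0 + 0.0j])

    jax.grad(f)(x)    # (A^H A x)^*: conjugate of the expected gradient
    scico.grad(f)(x)  # A^H A x: the expected gradient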

@@ -173,11 +173,11 @@ The following code demonstrates the use of ``jax.grad`` and :func:`scico.grad`:
Non-differentiable Functionals and scico.grad
---------------------------------------------

* :func:`scico.grad` can be applied to any function, but has undefined behavior for
non-differentiable functions.
* For non-differerentiable functions, :func:`scico.grad` may or may not return a valid subgradient. As an example, ``scico.grad(snp.abs)(0.) = 0``, which is a valid subgradient. However, ``scico.grad(snp.linalg.norm)([0., 0.]) = [nan, nan]``, which is not a valid subgradient of this function.
* Differentiable functions that are written as the composition of a differentiable and non-differentiable function should be avoided. As an example, :math:`f(x) = \norm{x}_2^2` can be implemented in as ``f = lambda x: snp.linalg.norm(x)**2``. This involves first calculating the non-squared :math:`\ell_2` norm, then squaring it. The un-squared :math:`\ell_2` norm is not differentiable at zero.
:func:`scico.grad` can be applied to any function, but has undefined behavior for
non-differentiable functions.
For non-differentiable functions, :func:`scico.grad` may or may not return a valid subgradient. As an example, ``scico.grad(snp.abs)(0.) = 0``, which is a valid subgradient. However, ``scico.grad(snp.linalg.norm)([0., 0.]) = [nan, nan]``.

Differentiable functions that are written as the composition of a differentiable and a non-differentiable function should be avoided. As an example, :math:`f(x) = \norm{x}_2^2` can be implemented as ``f = lambda x: snp.linalg.norm(x)**2``. This involves first calculating the non-squared :math:`\ell_2` norm, then squaring it. The un-squared :math:`\ell_2` norm is not differentiable at zero.
When evaluating the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``:

::
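
    import scico
    import scico.numpy as snp

    # a sketch; the original snippet is collapsed in this diff view
    f = lambda x: snp.linalg.norm(x) ** 2
    scico.grad(f)(snp.zeros(2))  # [nan, nan], although f is differentiable at 0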