Skip to content

Commit

Permalink
Autogenerate parameter types in documentation from python typehints.
Browse files Browse the repository at this point in the history
For example: if the method is...

```python
def forward(self, x1: Tensor, x2: Optional[Tensor]) -> Union[Tensor, LinearOperator]:
    r"""
    Does the forward thing.

    :param x1: The x1 arg
    :param x2: The x2 arg
    :return: The forward stuff.
    """
    # ...
```

The resulting documentation will include the appropriate types:

```
Parameters:

  - x1 (torch.Tensor) - The x1 arg
  - x2 (torch.Tensor, optional) - The x2 arg

Return type: torch.Tensor or LinearOperator

Returns: The forward stuff.
```

This (hopefully) should prevent a lot of duplicate effort on our end.

This commit also refactors the gpytorch.kernels.Kernel and
gpytorch.distributions.MultivariateNormal docs to utilize
auto-typing.
  • Loading branch information
gpleiss committed Sep 7, 2022
1 parent 019228d commit 3d1b3e5
Show file tree
Hide file tree
Showing 6 changed files with 459 additions and 211 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ ipython
ipykernel
sphinx
sphinx_rtd_theme
sphinx_autodoc_typehints
nbsphinx
m2r2
pyro-ppl
109 changes: 107 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import shutil
import sys
import sphinx_rtd_theme # noqa
import warnings
from typing import ForwardRef


def read(*names, **kwargs):
Expand Down Expand Up @@ -80,16 +82,25 @@ def find_version(*file_paths):
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.coverage",
"sphinx.ext.githubpages",
"sphinx.ext.intersphinx",
"sphinx.ext.mathjax",
'sphinx.ext.napoleon',
"sphinx.ext.viewcode",
"sphinx.ext.githubpages",
"sphinx.ext.autodoc",
"sphinx_autodoc_typehints",
"nbsphinx",
"m2r2",
]

# Configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", None),
}

# Disable docstring inheritance
autodoc_inherit_docstrings = False

Expand Down Expand Up @@ -210,3 +221,97 @@ def find_version(*file_paths):
"Miscellaneous",
)
]


# -- Function to format typehints ----------------------------------------------
# Adapted from
# https://github.com/cornellius-gp/linear_operator/blob/2b33b9f83b45f0cb8cb3490fc5f254cc59393c25/docs/source/conf.py
def _process(annotation, config):
"""
A function to convert a type/rtype typehint annotation into a :type:/:rtype: string.
This function is a bit hacky, and specific to the type annotations we use most frequently.
This function is recursive.
"""
# Simple/base case: any string annotation is ready to go
if type(annotation) == str:
return annotation

# Convert Ellipsis into "..."
elif annotation == Ellipsis:
return "..."

# Convert any class (i.e. torch.Tensor, LinearOperator, gpytorch, etc.) into appropriate strings
# For external classes, the format will be e.g. "torch.Tensor"
# For any linear_operator class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
# For any internal class, the format will be e.g. "~gpytorch.kernels.RBFKernel"
elif hasattr(annotation, "__name__"):
module = annotation.__module__ + "."
if module.split(".")[0] == "linear_operator":
if annotation.__name__.endswith("LinearOperator"):
module = f"~linear_operator."
elif annotation.__name__.endswith("LinearOperator"):
module = f"~linear_operator.operators."
else:
module = "~" + module
elif module.split(".")[0] == "gpytorch":
module = "~" + module
elif module == "builtins.":
module = ""
res = f"{module}{annotation.__name__}"

# Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
# Also, convert any Optional[*A*] into "*A*, optional"
elif str(annotation).startswith("typing.Union"):
is_optional_str = ""
args = list(annotation.__args__)
# Hack: Optional[*A*] are represented internally as Union[*A*, Nonetype]
# This catches this case
if args[-1] is type(None): # noqa E721
del args[-1]
is_optional_str = ", optional"
processed_args = [_process(arg, config) for arg in args]
res = " or ".join(processed_args) + is_optional_str

# Convert any Tuple[*A*, *B*] into "(*A*, *B*)"
elif str(annotation).startswith("typing.Tuple"):
args = list(annotation.__args__)
res = "(" + ", ".join(_process(arg, config) for arg in args) + ")"

# Convert any List[*A*] into "list(*A*)"
elif str(annotation).startswith("typing.List"):
arg = annotation.__args__[0]
res = "list(" + _process(arg, config) + ")"

# Convert any Iterable[*A*] into "iterable(*A*)"
elif str(annotation).startswith("typing.Iterable"):
arg = annotation.__args__[0]
res = "iterable(" + _process(arg, config) + ")"

# Handle "Callable"
elif str(annotation).startswith("typing.Callable"):
res = "callable"

# Handle "Any"
elif str(annotation).startswith("typing.Any"):
res = ""

# Special cases for forward references.
# This is brittle, as it only contains case for a select few forward refs
# All others that aren't caught by this are handled by the default case
elif isinstance(annotation, ForwardRef):
res = str(annotation.__forward_arg__)

# For everything we didn't catch: use the simplist string representation
else:
warnings.warn(f"No rule for {annotation}. Using default resolution...", RuntimeWarning)
res = str(annotation)

return res


# -- Options for typehints ----------------------------------------------
always_document_param_types = True
# typehints_use_rtype = False
typehints_defaults = None # or "comma"
simplify_optional_unions = False
typehints_formatter = _process
1 change: 1 addition & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ MultivariateNormal

.. autoclass:: MultivariateNormal
:members:
:special-members: __getitem__


MultitaskMultivariateNormal
Expand Down
1 change: 1 addition & 0 deletions docs/source/kernels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Kernel

.. autoclass:: Kernel
:members:
:special-members: __call__, __getitem__

Standard Kernels
-----------------------------
Expand Down
Loading

0 comments on commit 3d1b3e5

Please sign in to comment.