Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autogenerate parameter types in documentation from python typehints #2125

Merged
merged 1 commit into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ ipython
ipykernel
sphinx
sphinx_rtd_theme
sphinx_autodoc_typehints
nbsphinx
m2r2
pyro-ppl
109 changes: 107 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import shutil
import sys
import sphinx_rtd_theme # noqa
import warnings
from typing import ForwardRef


def read(*names, **kwargs):
Expand Down Expand Up @@ -80,16 +82,25 @@ def find_version(*file_paths):
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.coverage",
"sphinx.ext.githubpages",
"sphinx.ext.intersphinx",
"sphinx.ext.mathjax",
'sphinx.ext.napoleon',
"sphinx.ext.viewcode",
"sphinx.ext.githubpages",
"sphinx.ext.autodoc",
"sphinx_autodoc_typehints",
"nbsphinx",
"m2r2",
]

# Configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", None),
}

# Disable docstring inheritance
autodoc_inherit_docstrings = False

Expand Down Expand Up @@ -210,3 +221,97 @@ def find_version(*file_paths):
"Miscellaneous",
)
]


# -- Function to format typehints ----------------------------------------------
# Adapted from
# https://github.com/cornellius-gp/linear_operator/blob/2b33b9f83b45f0cb8cb3490fc5f254cc59393c25/docs/source/conf.py
def _process(annotation, config):
"""
A function to convert a type/rtype typehint annotation into a :type:/:rtype: string.
This function is a bit hacky, and specific to the type annotations we use most frequently.
This function is recursive.
"""
# Simple/base case: any string annotation is ready to go
if type(annotation) == str:
return annotation

# Convert Ellipsis into "..."
elif annotation == Ellipsis:
return "..."

# Convert any class (i.e. torch.Tensor, LinearOperator, gpytorch, etc.) into appropriate strings
# For external classes, the format will be e.g. "torch.Tensor"
# For any linear_operator class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
# For any internal class, the format will be e.g. "~gpytorch.kernels.RBFKernel"
elif hasattr(annotation, "__name__"):
module = annotation.__module__ + "."
if module.split(".")[0] == "linear_operator":
if annotation.__name__.endswith("LinearOperator"):
module = "~linear_operator."
elif annotation.__name__.endswith("LinearOperator"):
module = "~linear_operator.operators."
else:
module = "~" + module
elif module.split(".")[0] == "gpytorch":
module = "~" + module
elif module == "builtins.":
module = ""
res = f"{module}{annotation.__name__}"

# Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
# Also, convert any Optional[*A*] into "*A*, optional"
elif str(annotation).startswith("typing.Union"):
is_optional_str = ""
args = list(annotation.__args__)
# Hack: Optional[*A*] are represented internally as Union[*A*, Nonetype]
# This catches this case
if args[-1] is type(None): # noqa E721
del args[-1]
is_optional_str = ", optional"
processed_args = [_process(arg, config) for arg in args]
res = " or ".join(processed_args) + is_optional_str

# Convert any Tuple[*A*, *B*] into "(*A*, *B*)"
elif str(annotation).startswith("typing.Tuple"):
args = list(annotation.__args__)
res = "(" + ", ".join(_process(arg, config) for arg in args) + ")"

# Convert any List[*A*] into "list(*A*)"
elif str(annotation).startswith("typing.List"):
arg = annotation.__args__[0]
res = "list(" + _process(arg, config) + ")"

# Convert any Iterable[*A*] into "iterable(*A*)"
elif str(annotation).startswith("typing.Iterable"):
arg = annotation.__args__[0]
res = "iterable(" + _process(arg, config) + ")"

# Handle "Callable"
elif str(annotation).startswith("typing.Callable"):
res = "callable"

# Handle "Any"
elif str(annotation).startswith("typing.Any"):
res = ""

# Special cases for forward references.
# This is brittle, as it only contains case for a select few forward refs
# All others that aren't caught by this are handled by the default case
elif isinstance(annotation, ForwardRef):
res = str(annotation.__forward_arg__)

# For everything we didn't catch: use the simplist string representation
else:
warnings.warn(f"No rule for {annotation}. Using default resolution...", RuntimeWarning)
res = str(annotation)

return res


# -- Options for typehints ----------------------------------------------
always_document_param_types = True
# typehints_use_rtype = False
typehints_defaults = None # or "comma"
simplify_optional_unions = False
typehints_formatter = _process
1 change: 1 addition & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ MultivariateNormal

.. autoclass:: MultivariateNormal
:members:
:special-members: __getitem__


MultitaskMultivariateNormal
Expand Down
1 change: 1 addition & 0 deletions docs/source/kernels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Kernel

.. autoclass:: Kernel
:members:
:special-members: __call__, __getitem__

Standard Kernels
-----------------------------
Expand Down
Loading