Skip to content

Commit

Permalink
Add documentation for functional module
Browse files Browse the repository at this point in the history
This patch adds documentation for functional module.
  • Loading branch information
ybubnov committed Jul 3, 2024
1 parent 3faaa0f commit 3f8b89e
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 8 deletions.
9 changes: 9 additions & 0 deletions docs/_templates/function.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.. role:: hidden
:class: hidden-section

.. currentmodule:: {{ module }}


{{ name | underline }}

.. autofunction:: {{ name }}
35 changes: 35 additions & 0 deletions docs/reference/functional.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
.. _functional:

torch_geopooling.functional
===========================

.. automodule:: torch_geopooling.functional
.. currentmodule:: torch_geopooling.functional

TBD.


Pooling functions
-----------------

.. autosummary::
:nosignatures:
:toctree: generated
:template: function.rst

torch_geopooling.functional.avg_quad_pool2d
torch_geopooling.functional.max_quad_pool2d
torch_geopooling.functional.quad_pool2d


Adaptive pooling functions
--------------------------

.. autosummary::
:nosignatures:
:toctree: generated
:template: function.rst

torch_geopooling.functional.adaptive_avg_quad_pool2d
torch_geopooling.functional.adaptive_max_quad_pool2d
torch_geopooling.functional.adaptive_quad_pool2d
5 changes: 3 additions & 2 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
.. _torch_geopooling:
.. _torchgeopooling:

Reference
=========
API Reference
=============

.. automodule:: torch_geopooling
.. currentmodule:: torch_geopooling
Expand All @@ -14,4 +14,5 @@ applications using neural networks.
:maxdepth: 1

nn
functional
transforms
91 changes: 85 additions & 6 deletions torch_geopooling/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from textwrap import dedent, indent
from functools import partial
from inspect import signature
from typing import Callable, NamedTuple, Optional, Tuple

import torch
Expand All @@ -33,6 +36,17 @@
]


def __def__(fn: Callable, doc: str) -> Callable:
f = partial(fn)
f.__doc__ = doc + indent(dedent(fn.__doc__ or ""), " ")
f.__module__ = fn.__module__
f.__annotations__ = fn.__annotations__
f.__signature__ = signature(fn) # type: ignore
f.__defaults__ = fn.__defaults__ # type: ignore
f.__kwdefaults__ = fn.__kwdefaults__ # type: ignore
return f


class FunctionParams(NamedTuple):
max_terminal_nodes: Optional[int] = None
max_depth: Optional[int] = None
Expand Down Expand Up @@ -121,6 +135,21 @@ def func(
capacity: Optional[int] = None,
precision: Optional[int] = None,
) -> return_types.quad_pool2d:
"""
Args:
tiles: Tiles tensor representing tiles of a quadtree (both, internal and terminal).
weight: Weights tensor associated with each tile of a quadtree.
input: Input 2D coordinates as pairs of x (longitude) and y (latitude).
exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
training: True, when executed during training, and False otherwise.
max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once a
maximum is reached, internal nodes are no longer sub-divided and tree stops
growing.
max_depth: Maximum depth of the quadtree. Default: 17.
capacity: Maximum number of inputs, after which a quadtree's node is subdivided and
depth of the tree grows. Default: 1.
precision: Optional rounding of the input coordinates. Default: 7.
"""
params = FunctionParams(
max_terminal_nodes=max_terminal_nodes,
max_depth=max_depth,
Expand All @@ -147,9 +176,27 @@ class AvgQuadPool2d(Function):
backward_impl = _C.avg_quad_pool2d_backward


quad_pool2d = QuadPool2d.func
max_quad_pool2d = MaxQuadPool2d.func
avg_quad_pool2d = AvgQuadPool2d.func
quad_pool2d = __def__(
QuadPool2d.func,
"""Lookup index over quadtree decomposition of input 2D coordinates.
See :class:`torch_geopooling.nn.QuadPool2d` for more details.
""",
)
max_quad_pool2d = __def__(
MaxQuadPool2d.func,
"""Maximum pooling over quadtree decomposition of input 2D coordinates.
See :class:`torch_geopooling.nn.MaxQuadPool2d` for more details.
""",
)
avg_quad_pool2d = __def__(
AvgQuadPool2d.func,
"""Average pooling over quadtree decomposition of input 2D coordinates.
See :class:`torch_geopooling.nn.AvgQuadPool2d` for more details.
""",
)


class AdaptiveFunction(autograd.Function):
Expand Down Expand Up @@ -242,6 +289,20 @@ def func(
capacity: Optional[int] = None,
precision: Optional[int] = None,
) -> return_types.adaptive_quad_pool2d:
"""
Args:
weight: Weights tensor associated with each tile of a quadtree.
input: Input 2D coordinates as pairs of x (longitude) and y (latitude).
exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
training: True, when executed during training, and False otherwise.
max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once a
maximum is reached, internal nodes are no longer sub-divided and tree stops
growing.
max_depth: Maximum depth of the quadtree. Default: 17.
capacity: Maximum number of inputs, after which a quadtree's node is subdivided and
depth of the tree grows. Default: 1.
precision: Optional rounding of the input coordinates. Default: 7.
"""
params = FunctionParams(
max_terminal_nodes=max_terminal_nodes,
max_depth=max_depth,
Expand All @@ -268,6 +329,24 @@ class AdaptiveAvgQuadPool2d(AdaptiveFunction):
backward_impl = _C.avg_quad_pool2d_backward


adaptive_quad_pool2d = AdaptiveQuadPool2d.func
adaptive_max_quad_pool2d = AdaptiveMaxQuadPool2d.func
adaptive_avg_quad_pool2d = AdaptiveAvgQuadPool2d.func
adaptive_quad_pool2d = __def__(
AdaptiveQuadPool2d.func,
"""Adaptive lookup index over quadtree decomposition of input 2D coordinates.
See :class:`torch_geopooling.nn.AdaptiveQuadPool2d` for more details.
""",
)
adaptive_max_quad_pool2d = __def__(
AdaptiveMaxQuadPool2d.func,
"""Adaptive maximum pooling over quadtree decomposition of input 2D coordinates.
See :class:`torch_geopooling.nn.AdaptiveMaxQuadPool2d` for more details.
""",
)
adaptive_avg_quad_pool2d = __def__(
AdaptiveAvgQuadPool2d.func,
"""Adaptive average pooling over quadtree decomposition of input 2D coordinates.
See :class:`torch_geopooling.nn.AdaptiveAvgQuadPool2d` for more details.
""",
)

0 comments on commit 3f8b89e

Please sign in to comment.