DOC: Fix docstrings in gradient.py #415

Merged · 2 commits · Sep 5, 2023
20 changes: 9 additions & 11 deletions pytensor/gradient.py
@@ -196,12 +196,13 @@ def Rop(

Returns
-------
:class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
A symbolic expression obeying
``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all
coordinates of the tensor elements.
If `wrt` is a list/tuple, then return a list/tuple with the results.
If `f` is a list/tuple, then return a list/tuple with the results.
"""

if not isinstance(wrt, (list, tuple)):
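For context on the contract being documented, a minimal usage sketch of `Rop` (the toy function and variable names are illustrative, not part of the PR):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Rop

x = pt.vector("x")
v = pt.vector("v")                      # evaluation point (direction) for the R-operator
f = pt.sum(x ** 2)                      # d f / d x[j] = 2 * x[j]
jvp = Rop(f, x, v)                      # sum_j (d f / d x[j]) * v[j]
fn = pytensor.function([x, v], jvp)
fn(np.array([1.0, 2.0, 3.0]), np.ones(3))   # -> 12.0, i.e. 2*1 + 2*2 + 2*3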
@@ -384,6 +385,7 @@ def Lop(

Returns
-------
:class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
A symbolic expression satisfying
``L_op[j] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
where the indices in that expression are magic multidimensional
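The rest of this hunk is collapsed in the diff view; for contrast with `Rop`, a minimal sketch of the vector-Jacobian contract that the `Lop` docstring describes (names and values are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Lop

x = pt.vector("x")
w = pt.vector("w")                      # weights on the outputs f[i]
f = x ** 2                              # Jacobian is diag(2*x)
vjp = Lop(f, x, w)                      # L_op[j] = sum_i w[i] * d f[i] / d x[j]
fn = pytensor.function([x, w], vjp)
fn(np.array([1.0, 2.0, 3.0]), np.ones(3))   # -> [2., 4., 6.]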
@@ -481,10 +483,10 @@ def grad(

Returns
-------
:class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
A symbolic expression for the gradient of `cost` with respect to each
of the `wrt` terms. If an element of `wrt` is not differentiable with
respect to the output, then a zero variable is returned.

"""
t0 = time.perf_counter()
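As a quick illustration of the contract documented above for `grad` (the toy cost and names are not from the PR):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
cost = pt.sum(x ** 2)
gx = pytensor.grad(cost, wrt=x)         # one gradient variable per entry of `wrt`
fn = pytensor.function([x], gx)
fn(np.array([1.0, 2.0, 3.0]))           # -> [2., 4., 6.]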

@@ -701,7 +703,6 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):

Parameters
----------

wrt : list of variables
Gradients are computed with respect to `wrt`.

@@ -876,7 +877,6 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):

(A variable in consider_constant is not a function of
anything)

"""

# Validate and format consider_constant
@@ -1035,7 +1035,6 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
-------
list of Variables
A list of gradients corresponding to `wrt`

"""
# build a dict mapping node to the terms node contributes to each of
# its inputs' gradients
@@ -1423,8 +1422,9 @@ def access_grad_cache(var):


def _float_zeros_like(x):
"""Like zeros_like, but forces the object to have a
a floating point dtype"""
"""Like zeros_like, but forces the object to have
a floating point dtype
"""

rval = x.zeros_like()

@@ -1436,7 +1436,8 @@

def _float_ones_like(x):
"""Like ones_like, but forces the object to have a
floating point dtype"""
floating point dtype
"""

dtype = x.type.dtype
if dtype not in pytensor.tensor.type.float_dtypes:
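The two `_float_*_like` helpers above appear only as fragments in the diff; a self-contained sketch of the behaviour their docstrings describe (the fallback to `config.floatX` is an assumption about the elided bodies):

import pytensor
import pytensor.tensor

def float_zeros_like_sketch(x):
    # Like x.zeros_like(), but cast to a floating point dtype when x is not float.
    rval = x.zeros_like()
    if rval.type.dtype not in pytensor.tensor.type.float_dtypes:
        rval = rval.astype(pytensor.config.floatX)
    return rval

def float_ones_like_sketch(x):
    # Same idea for ones: pick a float dtype first, then build the ones tensor.
    dtype = x.type.dtype
    if dtype not in pytensor.tensor.type.float_dtypes:
        dtype = pytensor.config.floatX
    return x.ones_like(dtype=dtype)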
@@ -1613,7 +1614,6 @@ def abs_rel_errors(self, g_pt):

Corresponding ndarrays in `g_pt` and `self.gf` must have the same
shape or ValueError is raised.

"""
if len(g_pt) != len(self.gf):
raise ValueError("argument has wrong number of elements", len(g_pt))
@@ -1740,7 +1740,6 @@ def verify_grad(
This function does not support multiple outputs. In `tests.scan.test_basic`
there is an experimental `verify_grad` that covers that case as well by
using random projections.

"""
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import shared
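A hedged sketch of calling `verify_grad` on a toy function; the exact RNG type expected (a legacy `numpy.random.RandomState` vs. a `Generator`) depends on the PyTensor version, so treat the `rng` argument here as an assumption:

import numpy as np
from pytensor.gradient import verify_grad

rng = np.random.RandomState(42)         # explicit RNG for the random projection
# Compares the symbolic gradient of sum(x**2) against a finite-difference
# estimate at the given point; raises an error if they disagree beyond tolerance.
verify_grad(lambda x: (x ** 2).sum(), [np.asarray([0.1, 0.2, 0.3])], rng=rng)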
@@ -2267,7 +2266,6 @@ def grad_clip(x, lower_bound, upper_bound):
-----
We register an optimization in tensor/opt.py that removes the GradClip op,
so it has zero cost in the forward pass and only does work in the gradient.

"""
return GradClip(lower_bound, upper_bound)(x)
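A short sketch of the behaviour that note describes (values are illustrative); at x = 2 the unclipped gradient of x**2 would be 4, but the clip bounds cap it at 1:

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import grad_clip

x = pt.scalar("x")
y = grad_clip(x, -1.0, 1.0) ** 2        # forward value is just x**2
g = pytensor.grad(y, x)                 # backward gradient is clipped into [-1, 1]
f = pytensor.function([x], g)
f(2.0)                                  # ~1.0 instead of the unclipped 4.0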
