Limited verify_grad support for multiple output Ops
jessegrabowski committed Feb 14, 2024
1 parent b6c79fd commit 3f1f902
Showing 2 changed files with 40 additions and 15 deletions.
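The new keyword is easiest to see in use. Below is a minimal sketch, not taken from the commit: `two_outputs` and `x_val` are illustrative stand-ins, while `verify_grad` and `sum_outputs` are the function and flag this commit touches.

import numpy as np

import pytensor.tensor as pt
from pytensor.gradient import verify_grad

rng = np.random.default_rng(0)
x_val = rng.normal(size=(3,))

def two_outputs(x):
    # Any callable mapping symbolic inputs to a list of symbolic outputs.
    return [x**2, pt.sin(x)]

# Before this commit, any list-valued output raised NotImplementedError.
# With sum_outputs=True, verify_grad instead checks the gradient of a
# random projection summed over both outputs.
verify_grad(two_outputs, [x_val], rng=rng, sum_outputs=True)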
pytensor/gradient.py: 39 additions & 15 deletions
@@ -1675,6 +1675,7 @@ def verify_grad(
     mode: Optional[Union["Mode", str]] = None,
     cast_to_output_type: bool = False,
     no_debug_ref: bool = True,
+    sum_outputs: bool = False,
 ):
     """Test a gradient by Finite Difference Method. Raise error on failure.
@@ -1722,7 +1723,9 @@ def verify_grad(
     float16 is not handled here.
     no_debug_ref
         Don't use `DebugMode` for the numerical gradient function.
+    sum_outputs : bool, default False
+        If True, the gradient of the sum of random projections of all outputs
+        is verified. If False, an error is raised if the function has multiple outputs.
 
     Notes
     -----
     This function does not support multiple outputs. In `tests.scan.test_basic`
@@ -1782,7 +1785,7 @@ def verify_grad(
     # fun can be either a function or an actual Op instance
     o_output = fun(*tensor_pt)
 
-    if isinstance(o_output, list):
+    if isinstance(o_output, list) and not sum_outputs:
         raise NotImplementedError(
             "Can't (yet) auto-test the gradient of a function with multiple outputs"
         )
@@ -1793,7 +1796,7 @@ def verify_grad(
     o_fn = fn_maker(tensor_pt, o_output, name="gradient.py fwd")
     o_fn_out = o_fn(*[p.copy() for p in pt])
 
-    if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list):
+    if (isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list)) and not sum_outputs:
         raise TypeError(
             "It seems like you are trying to use verify_grad "
             "on an Op or a function which outputs a list: there should"
@@ -1802,33 +1805,45 @@ def verify_grad(

     # random_projection should not have elements too small,
     # otherwise too much precision is lost in numerical gradient
-    def random_projection():
-        plain = rng.random(o_fn_out.shape) + 0.5
-        if cast_to_output_type and o_output.dtype == "float32":
-            return np.array(plain, o_output.dtype)
+    def random_projection(shape, dtype):
+        plain = rng.random(shape) + 0.5
+        if cast_to_output_type and dtype == "float32":
+            return np.array(plain, dtype)
         return plain
 
-    t_r = shared(random_projection(), borrow=True)
-    t_r.name = "random_projection"
 
     # random projection of o onto t_r
     # This sum() is defined above, it's not the builtin sum.
-    cost = pytensor.tensor.sum(t_r * o_output)
+    if sum_outputs:
+        t_rs = [
+            shared(
+                value=random_projection(o.shape, o.dtype),
+                borrow=True,
+                name=f"random_projection_{i}",
+            )
+            for i, o in enumerate(o_fn_out)
+        ]
+        cost = pytensor.tensor.sum(
+            [pytensor.tensor.sum(x * y) for x, y in zip(t_rs, o_output)]
+        )
+    else:
+        t_r = shared(
+            value=random_projection(o_fn_out.shape, o_fn_out.dtype),
+            borrow=True,
+            name="random_projection",
+        )
+        cost = pytensor.tensor.sum(t_r * o_output)
 
     if no_debug_ref:
         mode_for_cost = mode_not_slow(mode)
     else:
         mode_for_cost = mode
 
     cost_fn = fn_maker(tensor_pt, cost, name="gradient.py cost", mode=mode_for_cost)
 
     symbolic_grad = grad(cost, tensor_pt, disconnected_inputs="ignore")
 
     grad_fn = fn_maker(tensor_pt, symbolic_grad, name="gradient.py symbolic grad")
 
     for test_num in range(n_tests):
         num_grad = numeric_grad(cost_fn, [p.copy() for p in pt], eps, out_type)
 
         analytic_grad = grad_fn(*[p.copy() for p in pt])
 
         # Since `tensor_pt` is a list, `analytic_grad` should be one too.
@@ -1853,7 +1868,16 @@ def random_projection():

         # get new random projection for next test
         if test_num < n_tests - 1:
-            t_r.set_value(random_projection(), borrow=True)
+            if sum_outputs:
+                for r in t_rs:
+                    r.set_value(
+                        random_projection(r.get_value().shape, r.get_value().dtype)
+                    )
+            else:
+                t_r.set_value(
+                    random_projection(t_r.get_value().shape, t_r.get_value().dtype),
+                    borrow=True,
+                )


 class GradientError(Exception):
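A note on the construction above: `verify_grad` never differentiates the raw outputs. It contracts each output with a fixed random array `t_r` and checks the gradient of that scalar cost, which is `J^T t_r` for the output's Jacobian `J`; with `sum_outputs=True` the cost gets one such term per output. A self-contained NumPy sketch of the same check, with an illustrative toy function in place of a real Op:

import numpy as np

rng = np.random.default_rng(0)

def f(x):
    # Toy stand-in for a multiple-output Op.
    return [x**2, np.sin(x)]

x = rng.normal(size=3)
# One random projection per output, bounded away from zero as above.
t_rs = [rng.random(o.shape) + 0.5 for o in f(x)]

def cost(v):
    # Scalar cost: sum of the projected outputs, as in the sum_outputs branch.
    return sum((t * o).sum() for t, o in zip(t_rs, f(v)))

# Analytic gradient of the projected cost: sum_i J_i^T t_r_i.
analytic = 2 * x * t_rs[0] + np.cos(x) * t_rs[1]

# Central finite differences, the same idea numeric_grad implements.
eps = 1e-6
numeric = np.array(
    [(cost(x + eps * e) - cost(x - eps * e)) / (2 * eps) for e in np.eye(x.size)]
)
assert np.allclose(analytic, numeric, atol=1e-5)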
tests/tensor/test_nlinalg.py: 1 addition & 0 deletions
@@ -250,6 +250,7 @@ def test_grad(self, compute_uv, full_matrices, shape, batched):
                 partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
                 [A_v],
                 rng=rng,
+                sum_outputs=True,
             )
 
         else:
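For context, the amended test verifies all three SVD outputs in one call. A hedged reconstruction of that call follows: the `rng` and `A_v` setup is not shown in the diff and is assumed here, the import path matches the test file's location, and it assumes the gradient of `svd` with `compute_uv=True` is implemented, as the test implies.

from functools import partial

import numpy as np

from pytensor.gradient import verify_grad
from pytensor.tensor.nlinalg import svd

rng = np.random.default_rng(20240214)
A_v = rng.normal(size=(4, 4))

# svd with compute_uv=True returns (U, s, Vh), so the single-output path
# would raise; sum_outputs=True sums one random projection per output.
verify_grad(
    partial(svd, compute_uv=True, full_matrices=False),
    [A_v],
    rng=rng,
    sum_outputs=True,
)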
