Commit
feat: value_and_grad with support for auxiliary data
1 parent 562e551, commit 996fa36
Showing 4 changed files with 96 additions and 0 deletions.
tests/test_plugins/autograd/test_differential_operators.py (23 additions, 0 deletions)
@@ -0,0 +1,23 @@
import autograd.numpy as np
from autograd import value_and_grad as value_and_grad_ag
from numpy.testing import assert_allclose
from tidy3d.plugins.autograd.differential_operators import value_and_grad


def test_value_and_grad(rng):
    """Test the custom value_and_grad function against autograd's implementation"""
    x = rng.random(10)
    aux_val = "aux"

    vg_fun = value_and_grad(lambda x: (np.linalg.norm(x), aux_val), has_aux=True)
    vg_fun_ag = value_and_grad_ag(lambda x: np.linalg.norm(x))

    (v, g), aux = vg_fun(x)
    v_ag, g_ag = vg_fun_ag(x)

    # assert that values and gradients match
    assert_allclose(v, v_ag)
    assert_allclose(g, g_ag)

    # check that auxiliary output is correctly returned
    assert aux == aux_val
@@ -0,0 +1,5 @@
from .differential_operators import value_and_grad

__all__ = [
    "value_and_grad",
]
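For context, the file above re-exports value_and_grad at package level. The file path is not shown in this view; assuming it is the autograd plugin's __init__.py (consistent with the test's import of tidy3d.plugins.autograd.differential_operators), the wrapper could then be imported from either level. A minimal sketch under that assumption:

# Hypothetical usage; assumes the __init__.py above lives in
# tidy3d/plugins/autograd/, which this diff view does not confirm.
from tidy3d.plugins.autograd import value_and_grad
from tidy3d.plugins.autograd.differential_operators import value_and_grad as vag_direct

assert value_and_grad is vag_direct  # same object, just re-exported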
@@ -0,0 +1,65 @@
from typing import Any, Callable

from autograd import value_and_grad as value_and_grad_ag
from autograd.builtins import tuple as atuple
from autograd.core import make_vjp
from autograd.extend import vspace
from autograd.wrap_util import unary_to_nary
from numpy.typing import ArrayLike

__all__ = [
    "value_and_grad",
]


@unary_to_nary
def value_and_grad(
    fun: Callable, x: ArrayLike, *, has_aux: bool = False
) -> tuple[tuple[float, ArrayLike], Any]:
    """Returns a function that returns both value and gradient.

    This function wraps and extends autograd's 'value_and_grad' function by adding
    support for auxiliary data.

    Parameters
    ----------
    fun : Callable
        The function to differentiate. Should take a single argument and return
        a scalar value, or a tuple where the first element is a scalar value if has_aux is True.
    x : ArrayLike
        The point at which to evaluate the function and its gradient.
    has_aux : bool = False
        If True, the function returns auxiliary data as the second element of a tuple.

    Returns
    -------
    tuple[tuple[float, ArrayLike], Any]
        A tuple containing:
        - A tuple with the function value (float) and its gradient (ArrayLike)
        - The auxiliary data returned by the function (if has_aux is True)

    Raises
    ------
    TypeError
        If the function does not return a scalar value.

    Notes
    -----
    This function uses autograd for automatic differentiation. If the function
    does not return auxiliary data (has_aux is False), it delegates to autograd's
    value_and_grad function. The main extension is the support for auxiliary data
    when has_aux is True.
    """
    if not has_aux:
        return value_and_grad_ag(fun)(x)

    vjp, (ans, aux) = make_vjp(lambda x: atuple(fun(x)), x)

    if not vspace(ans).size == 1:
        raise TypeError(
            "value_and_grad only applies to real scalar-output "
            "functions. Try jacobian, elementwise_grad or "
            "holomorphic_grad."
        )

    return (ans, vjp((vspace(ans).ones(), None))), aux
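As a usage illustration of the new has_aux path (the objective function and array below are assumptions for this sketch, not part of the commit), the wrapper is called just like autograd's value_and_grad, but additionally returns the auxiliary output alongside the (value, gradient) pair:

import autograd.numpy as np

from tidy3d.plugins.autograd.differential_operators import value_and_grad


def objective(x):
    # Illustrative scalar objective; any extra, non-differentiated output can be
    # returned as auxiliary data when has_aux=True (the test above uses a string).
    value = np.sum(x**2)
    aux = "quadratic objective"
    return value, aux


x0 = np.linspace(0.0, 1.0, 5)

# With auxiliary data: returns ((value, gradient), aux).
(value, grad), aux = value_and_grad(objective, has_aux=True)(x0)

# Without auxiliary data the call delegates to autograd's value_and_grad
# and returns (value, gradient) directly.
value_only, grad_only = value_and_grad(lambda x: np.sum(x**2))(x0)

For this objective the gradient is simply 2 * x0, so the has_aux=True result matches what autograd's own value_and_grad would return for the same function, with the auxiliary string passed through unchanged.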