
Commit

feat: value_and_grad with support for auxiliary data
yaugenst-flex committed Aug 7, 2024
1 parent 562e551 commit 996fa36
Showing 4 changed files with 96 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Added `value_and_grad` function to the autograd plugin, importable via `from tidy3d.plugins.autograd import value_and_grad`. Supports differentiating functions with auxiliary data (`value_and_grad(f, has_aux=True)`).

## [2.7.2] - 2024-08-07

### Added
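A minimal usage sketch of the import path described in the [Unreleased] changelog entry above; this is not part of the commit, and the quadratic objective and input values are illustrative assumptions:

```python
import autograd.numpy as np

from tidy3d.plugins.autograd import value_and_grad

# Without auxiliary data, the wrapper delegates to autograd's value_and_grad.
value, grad = value_and_grad(lambda x: np.sum(x**2))(np.array([1.0, 2.0, 3.0]))
print(value, grad)  # 14.0 [2. 4. 6.]
```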
23 changes: 23 additions & 0 deletions tests/test_plugins/autograd/test_differential_operators.py
@@ -0,0 +1,23 @@
import autograd.numpy as np
from autograd import value_and_grad as value_and_grad_ag
from numpy.testing import assert_allclose
from tidy3d.plugins.autograd.differential_operators import value_and_grad


def test_value_and_grad(rng):
    """Test the custom value_and_grad function against autograd's implementation"""
    x = rng.random(10)
    aux_val = "aux"

    vg_fun = value_and_grad(lambda x: (np.linalg.norm(x), aux_val), has_aux=True)
    vg_fun_ag = value_and_grad_ag(lambda x: np.linalg.norm(x))

    (v, g), aux = vg_fun(x)
    v_ag, g_ag = vg_fun_ag(x)

    # assert that values and gradients match
    assert_allclose(v, v_ag)
    assert_allclose(g, g_ag)

    # check that auxiliary output is correctly returned
    assert aux == aux_val
5 changes: 5 additions & 0 deletions tidy3d/plugins/autograd/__init__.py
@@ -0,0 +1,5 @@
from .differential_operators import value_and_grad

__all__ = [
    "value_and_grad",
]
65 changes: 65 additions & 0 deletions tidy3d/plugins/autograd/differential_operators.py
@@ -0,0 +1,65 @@
from typing import Any, Callable

from autograd import value_and_grad as value_and_grad_ag
from autograd.builtins import tuple as atuple
from autograd.core import make_vjp
from autograd.extend import vspace
from autograd.wrap_util import unary_to_nary
from numpy.typing import ArrayLike

__all__ = [
    "value_and_grad",
]


@unary_to_nary
def value_and_grad(
    fun: Callable, x: ArrayLike, *, has_aux: bool = False
) -> tuple[tuple[float, ArrayLike], Any]:
    """Returns a function that returns both value and gradient.

    This function wraps and extends autograd's 'value_and_grad' function by adding
    support for auxiliary data.

    Parameters
    ----------
    fun : Callable
        The function to differentiate. Should take a single argument and return
        a scalar value, or a tuple where the first element is a scalar value if has_aux is True.
    x : ArrayLike
        The point at which to evaluate the function and its gradient.
    has_aux : bool = False
        If True, the function returns auxiliary data as the second element of a tuple.

    Returns
    -------
    tuple[tuple[float, ArrayLike], Any]
        A tuple containing:
        - A tuple with the function value (float) and its gradient (ArrayLike)
        - The auxiliary data returned by the function (if has_aux is True)

    Raises
    ------
    TypeError
        If the function does not return a scalar value.

    Notes
    -----
    This function uses autograd for automatic differentiation. If the function
    does not return auxiliary data (has_aux is False), it delegates to autograd's
    value_and_grad function. The main extension is the support for auxiliary data
    when has_aux is True.
    """
    if not has_aux:
        return value_and_grad_ag(fun)(x)

    vjp, (ans, aux) = make_vjp(lambda x: atuple(fun(x)), x)

    if not vspace(ans).size == 1:
        raise TypeError(
            "value_and_grad only applies to real scalar-output "
            "functions. Try jacobian, elementwise_grad or "
            "holomorphic_grad."
        )

    return (ans, vjp((vspace(ans).ones(), None))), aux
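
For reference, a minimal sketch of the `has_aux` path implemented above, mirroring the structure of the test; the specific objective function and auxiliary dictionary are assumptions for illustration, not part of the commit:

```python
import autograd.numpy as np

from tidy3d.plugins.autograd import value_and_grad


def objective(x):
    """Scalar objective that also returns metadata as auxiliary output."""
    value = np.sum(np.sin(x) ** 2)
    return value, {"note": "objective evaluated", "size": x.size}


x0 = np.linspace(0.0, 1.0, 5)
# With has_aux=True, the auxiliary data is returned alongside (value, gradient).
(val, grad), aux = value_and_grad(objective, has_aux=True)(x0)
print(val, grad.shape, aux["note"])
```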
