From 996fa3692847cd538eab8f3437921eb35bbe8293 Mon Sep 17 00:00:00 2001
From: Yannick Augenstein
Date: Wed, 7 Aug 2024 17:24:45 +0200
Subject: [PATCH] feat: `value_and_grad` with support for auxiliary data

---
 CHANGELOG.md                                 |  3 +
 .../autograd/test_differential_operators.py  | 23 +++++++
 tidy3d/plugins/autograd/__init__.py          |  5 ++
 .../autograd/differential_operators.py       | 65 +++++++++++++++++++
 4 files changed, 96 insertions(+)
 create mode 100644 tests/test_plugins/autograd/test_differential_operators.py
 create mode 100644 tidy3d/plugins/autograd/differential_operators.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 063e84875..c51dacaf7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Added
+- Added `value_and_grad` function to the autograd plugin, importable via `from tidy3d.plugins.autograd import value_and_grad`. Supports differentiating functions with auxiliary data (`value_and_grad(f, has_aux=True)`).
+
 ## [2.7.2] - 2024-08-07
 
 ### Added
diff --git a/tests/test_plugins/autograd/test_differential_operators.py b/tests/test_plugins/autograd/test_differential_operators.py
new file mode 100644
index 000000000..21fbe1b1a
--- /dev/null
+++ b/tests/test_plugins/autograd/test_differential_operators.py
@@ -0,0 +1,23 @@
+import autograd.numpy as np
+from autograd import value_and_grad as value_and_grad_ag
+from numpy.testing import assert_allclose
+from tidy3d.plugins.autograd.differential_operators import value_and_grad
+
+
+def test_value_and_grad(rng):
+    """Test the custom value_and_grad function against autograd's implementation"""
+    x = rng.random(10)
+    aux_val = "aux"
+
+    vg_fun = value_and_grad(lambda x: (np.linalg.norm(x), aux_val), has_aux=True)
+    vg_fun_ag = value_and_grad_ag(lambda x: np.linalg.norm(x))
+
+    (v, g), aux = vg_fun(x)
+    v_ag, g_ag = vg_fun_ag(x)
+
+    # assert that values and gradients match
+    assert_allclose(v, v_ag)
+    assert_allclose(g, g_ag)
+
+    # check that auxiliary output is correctly returned
+    assert aux == aux_val
diff --git a/tidy3d/plugins/autograd/__init__.py b/tidy3d/plugins/autograd/__init__.py
index e69de29bb..dcbf05eb0 100644
--- a/tidy3d/plugins/autograd/__init__.py
+++ b/tidy3d/plugins/autograd/__init__.py
@@ -0,0 +1,5 @@
+from .differential_operators import value_and_grad
+
+__all__ = [
+    "value_and_grad",
+]
diff --git a/tidy3d/plugins/autograd/differential_operators.py b/tidy3d/plugins/autograd/differential_operators.py
new file mode 100644
index 000000000..e83c439e1
--- /dev/null
+++ b/tidy3d/plugins/autograd/differential_operators.py
@@ -0,0 +1,65 @@
+from typing import Any, Callable
+
+from autograd import value_and_grad as value_and_grad_ag
+from autograd.builtins import tuple as atuple
+from autograd.core import make_vjp
+from autograd.extend import vspace
+from autograd.wrap_util import unary_to_nary
+from numpy.typing import ArrayLike
+
+__all__ = [
+    "value_and_grad",
+]
+
+
+@unary_to_nary
+def value_and_grad(
+    fun: Callable, x: ArrayLike, *, has_aux: bool = False
+) -> tuple[tuple[float, ArrayLike], Any]:
+    """Returns a function that returns both value and gradient.
+
+    This function wraps and extends autograd's 'value_and_grad' function by adding
+    support for auxiliary data.
+
+    Parameters
+    ----------
+    fun : Callable
+        The function to differentiate. Should take a single argument and return
+        a scalar value, or a tuple where the first element is a scalar value if has_aux is True.
+    x : ArrayLike
+        The point at which to evaluate the function and its gradient.
+    has_aux : bool = False
+        If True, the function returns auxiliary data as the second element of a tuple.
+
+    Returns
+    -------
+    tuple[tuple[float, ArrayLike], Any]
+        A tuple containing:
+        - A tuple with the function value (float) and its gradient (ArrayLike)
+        - The auxiliary data returned by the function (if has_aux is True)
+
+    Raises
+    ------
+    TypeError
+        If the function does not return a scalar value.
+
+    Notes
+    -----
+    This function uses autograd for automatic differentiation. If the function
+    does not return auxiliary data (has_aux is False), it delegates to autograd's
+    value_and_grad function. The main extension is the support for auxiliary data
+    when has_aux is True.
+    """
+    if not has_aux:
+        return value_and_grad_ag(fun)(x)
+
+    vjp, (ans, aux) = make_vjp(lambda x: atuple(fun(x)), x)
+
+    if not vspace(ans).size == 1:
+        raise TypeError(
+            "value_and_grad only applies to real scalar-output "
+            "functions. Try jacobian, elementwise_grad or "
+            "holomorphic_grad."
+        )
+
+    return (ans, vjp((vspace(ans).ones(), None))), aux
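
Below is a minimal usage sketch of the API this patch adds (not part of the diff itself). The call pattern mirrors the included test; the `objective` function, its auxiliary dictionary, and the starting point `x0` are illustrative placeholders:

    import autograd.numpy as anp

    from tidy3d.plugins.autograd import value_and_grad

    def objective(x):
        # scalar objective plus arbitrary auxiliary data
        value = anp.sum(x**2)
        aux = {"note": "anything returned alongside the scalar value"}
        return value, aux

    x0 = anp.linspace(-1.0, 1.0, 5)

    # with has_aux=True the wrapper returns ((value, gradient), auxiliary_data)
    (value, grad), aux = value_and_grad(objective, has_aux=True)(x0)

Without `has_aux`, the call delegates to autograd's own `value_and_grad` and returns only `(value, gradient)`.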