Skip to content

Commit

Permalink
broadband adjoint support for autograd
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Jul 12, 2024
1 parent 6845aa4 commit d279e19
Show file tree
Hide file tree
Showing 8 changed files with 350 additions and 87 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Automatic differentiation with `autograd` supports multiple frequencies through single, broadband adjoint simulation.

### Fixed
- Error when loading a previously run `Batch` or `ComponentModeler` containing custom data.

Expand Down
222 changes: 189 additions & 33 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import numpy as np
import pytest
import tidy3d as td
import xarray as xr
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
from tidy3d.web import run_async
from tidy3d.web.api.autograd.autograd import run
from tidy3d.web import run, run_async

from ..utils import SIM_FULL, AssertLogLevel, run_emulated

Expand Down Expand Up @@ -53,6 +53,7 @@

WVL = 1.0
FREQ0 = td.C_0 / WVL
FREQS = [0.9 * FREQ0, FREQ0, FREQ0 * 1.1]

# sim sizes
LZ = 7 * WVL
Expand Down Expand Up @@ -154,8 +155,10 @@ def run_async_emulated(simulations, **kwargs):
def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
"""Make a dictionary of the structures given the parameters."""

np.random.seed(0)

vector = np.random.random(N_PARAMS) - 0.5
vector /= np.linalg.norm(vector)
vector = vector / np.linalg.norm(vector)

# static components
box = td.Box(center=(0, 0, 0), size=(1, 1, 1))
Expand Down Expand Up @@ -411,7 +414,8 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = False) -> None:
if TEST_POLYSLAB_SPEED:
args = [("polyslab", "mode")]

# args = [("geo_group", "mode")]

# args = [("custom_med", "mode")]


def get_functions(structure_key: str, monitor_key: str) -> typing.Callable:
Expand Down Expand Up @@ -612,33 +616,13 @@ def objective(*params):
ag.grad(objective)(params0)


def test_warning_no_adjoint_sources(log_capture, monkeypatch, use_emulated_run):
"""Make sure we get the right warning with no adjoint sources, and no error."""

monitor_key = "mode"
structure_key = "size_element"
monitor, postprocess = make_monitors()[monitor_key]

def make_sim(*args):
structure = make_structures(*args)[structure_key]
return SIM_BASE.updated_copy(structures=[structure], monitors=[monitor])

def objective(*args):
"""Objective function."""
sim = make_sim(*args)
data = run(sim, task_name="autograd_test", verbose=False)
value = postprocess(data, data[monitor_key])
return value

monkeypatch.setattr(td.SimulationData, "make_adjoint_sources", lambda *args, **kwargs: [])

with AssertLogLevel(log_capture, "WARNING", contains_str="No adjoint sources"):
ag.grad(objective)(params0)


def test_web_failure_handling(log_capture, monkeypatch, use_emulated_run, use_emulated_run_async):
"""Test what happens when autograd run pipeline fails."""

def fail(*args, **kwargs):
"""Just raise an exception."""
raise ValueError("test")

monitor_key = "mode"
structure_key = "size_element"
monitor, postprocess = make_monitors()[monitor_key]
Expand All @@ -650,14 +634,10 @@ def make_sim(*args):
def objective(*args):
"""Objective function."""
sim = make_sim(*args)
data = run(sim, task_name="autograd_test", verbose=False)
data = run(sim, task_name=None, verbose=False)
value = postprocess(data, data[monitor_key])
return value

def fail(*args, **kwargs):
"""Just raise an exception."""
raise ValueError("test")

""" if autograd run raises exception, raise a warning and continue with regular ."""

monkeypatch.setattr(td.web.api.autograd.autograd, "_run", fail)
Expand Down Expand Up @@ -1008,3 +988,179 @@ def f(x):
* no copy : 16 sec
* no to_static(): 13 sec
"""

FREQ1 = FREQ0 * 1.1

mnt_single = td.ModeMonitor(
size=(2, 2, 0),
center=(0, 0, LZ / 2 - WVL),
mode_spec=td.ModeSpec(num_modes=2),
freqs=[FREQ0],
name="single",
)

mnt_multi = td.ModeMonitor(
size=(2, 2, 0),
center=(0, 0, LZ / 2 - WVL),
mode_spec=td.ModeSpec(num_modes=2),
freqs=[FREQ0, FREQ1],
name="multi",
)


def make_objective(postprocess_fn: typing.Callable, structure_key: str) -> typing.Callable:
def objective(params):
structure_traced = make_structures(params)[structure_key]
sim = SIM_BASE.updated_copy(
structures=[structure_traced],
monitors=list(SIM_BASE.monitors) + [mnt_single, mnt_multi],
)
data = run(sim, task_name="multifreq_test")
return postprocess_fn(data)

return objective


def get_amps(sim_data: td.SimulationData, mnt_name: str) -> xr.DataArray:
return sim_data[mnt_name].amps


def power(amps: xr.DataArray) -> float:
"""Reduce a selected DataArray into just a float for objective function."""
return anp.sum(anp.abs(amps.values) ** 2)


def postprocess_0_src(sim_data: td.SimulationData) -> float:
"""Postprocess function that should return 0 adjoint sources."""
return 0.0


def compute_grad(postprocess_fn: typing.Callable, structure_key: str) -> typing.Callable:
objective = make_objective(postprocess_fn, structure_key=structure_key)
params = params0 + 1.0 # +1 is to avoid a warning in size_element with value 0
return ag.grad(objective)(params)


def check_1_src_single(log_capture, structure_key):
def postprocess(sim_data: td.SimulationData) -> float:
"""Postprocess function that should return 1 adjoint sources."""
amps = get_amps(sim_data, "single").sel(mode_index=0, direction="+")
return power(amps)


def check_2_src_single(log_capture, structure_key):
def postprocess_2_src_single(sim_data: td.SimulationData) -> float:
"""Postprocess function that should return 2 different adjoint sources."""
amps = get_amps(sim_data, "single").sel(mode_index=0)
return power(amps)


def check_1_src_multi(log_capture, structure_key):
def postprocess(sim_data: td.SimulationData) -> float:
"""Postprocess function that should return 1 adjoint sources."""
amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0)
return power(amps)


def check_2_src_multi(log_capture, structure_key):
def postprocess(sim_data: td.SimulationData) -> float:
"""Postprocess function that should return 2 different adjoint sources."""
amps = get_amps(sim_data, "multi").sel(mode_index=0, f=FREQ1)
return power(amps)


def check_2_src_both(log_capture, structure_key):
def postprocess(sim_data: td.SimulationData) -> float:
"""Postprocess function that should return 2 different adjoint sources."""
amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+")
amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0)
return power(amps_single) + power(amps_multi)


def check_1_multisrc(log_capture, structure_key):
def postprocess(sim_data: td.SimulationData) -> float:
"""Postprocess function that should raise ValueError because diff sources, diff freqs."""
amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+")
amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ1)
return power(amps_single) + power(amps_multi)


def check_2_multisrc(log_capture, structure_key):
def postprocess(sim_data: td.SimulationData) -> float:
"""Postprocess function that should raise ValueError because diff sources, diff freqs."""
amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+")
amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+")
return power(amps_single) + power(amps_multi)


def check_1_src_broadband(log_capture, structure_key):
def postprocess(sim_data: td.SimulationData) -> float:
"""Postprocess function that should return 1 broadband adjoint sources with many freqs."""
amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+")
return power(amps)


MULT_FREQ_TEST_CASES = dict(
src_1_freq_1=check_1_src_single,
src_2_freq_1=check_2_src_single,
src_1_freq_2=check_1_src_multi,
src_2_freq_1_mon_1=check_1_src_multi,
src_2_freq_1_mon_2=check_2_src_both,
src_2_freq_2_mon_1=check_1_multisrc,
src_2_freq_2_mon_2=check_2_multisrc,
src_1_freq_2_broadband=check_1_src_broadband,
)

checks = list(MULT_FREQ_TEST_CASES.items())


@pytest.mark.parametrize("label, check_fn", checks)
@pytest.mark.parametrize("structure_key", structure_keys_)
def test_multi_freq_edge_cases(log_capture, use_emulated_run, structure_key, label, check_fn):
# test multi-frequency adjoint handling
check_fn(structure_key=structure_key, log_capture=log_capture)


@pytest.mark.parametrize("structure_key", structure_keys_)
def test_multi_frequency_equivalence(use_emulated_run, structure_key):
"""Test an objective function through tidy3d autograd."""

def objective_indi(params, structure_key) -> float:
power_sum = 0.0

for f in mnt_multi.freqs:
structure_traced = make_structures(params)[structure_key]
sim = SIM_BASE.updated_copy(
structures=[structure_traced],
monitors=list(SIM_BASE.monitors) + [mnt_multi],
)

sim_data = run(sim, task_name="multifreq_test")
amps_i = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=f)
power_i = power(amps_i)
power_sum = power_sum + power_i

return power_sum

def objective_multi(params, structure_key) -> float:
structure_traced = make_structures(params)[structure_key]
sim = SIM_BASE.updated_copy(
structures=[structure_traced],
monitors=list(SIM_BASE.monitors) + [mnt_multi],
)
sim_data = run(sim, task_name="multifreq_test")
amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+")
return power(amps)

params0_ = params0 + 1.0

J_indi = objective_indi(params0_, structure_key)
J_multi = objective_multi(params0_, structure_key)

np.testing.assert_allclose(J_indi, J_multi)

grad_indi = ag.grad(objective_indi)(params0_, structure_key=structure_key)
grad_multi = ag.grad(objective_multi)(params0_, structure_key=structure_key)

assert not np.any(np.isclose(grad_indi, 0))
assert not np.any(np.isclose(grad_multi, 0))
8 changes: 5 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
""" utilities shared between all tests """
np.random.seed(4)

# function used to generate the data for emulated runs
DATA_GEN_FN = np.random.random

FREQS = np.array([1.90, 2.01, 2.2]) * 1e12
SIM_MONITORS = td.Simulation(
Expand Down Expand Up @@ -880,7 +882,7 @@ def make_data(
"""make a random DataArray out of supplied coordinates and data_type."""
data_shape = [len(coords[k]) for k in data_array_type._dims]
np.random.seed(1)
data = np.random.random(data_shape)
data = DATA_GEN_FN(data_shape)

data = (1 + 0.5j) * data if is_complex else data
data = gaussian_filter(data, sigma=1.0) # smooth out the data a little so it isnt random
Expand Down Expand Up @@ -939,7 +941,7 @@ def make_mode_solver_data(monitor: td.ModeSolverMonitor) -> td.ModeSolverData:
index_coords["mode_index"] = np.arange(monitor.mode_spec.num_modes)
index_data_shape = (len(index_coords["f"]), len(index_coords["mode_index"]))
index_data = ModeIndexDataArray(
(1 + 1j) * np.random.random(index_data_shape), coords=index_coords
(1 + 1j) * DATA_GEN_FN(index_data_shape), coords=index_coords
)
for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]:
coords = get_spatial_coords_dict(simulation, monitor, field_name)
Expand Down Expand Up @@ -977,7 +979,7 @@ def make_diff_data(monitor: td.DiffractionMonitor) -> td.DiffractionData:
orders_x = np.linspace(-1, 1, 3)
orders_y = np.linspace(-2, 2, 5)
coords = dict(orders_x=orders_x, orders_y=orders_y, f=f)
values = np.random.random((len(orders_x), len(orders_y), len(f)))
values = DATA_GEN_FN((len(orders_x), len(orders_y), len(f)))
data = td.DiffractionDataArray(values, coords=coords)
field_data = {field: data for field in ("Er", "Etheta", "Ephi", "Hr", "Htheta", "Hphi")}
return td.DiffractionData(monitor=monitor, sim_size=(1, 1), bloch_vecs=(0, 0), **field_data)
Expand Down
1 change: 1 addition & 0 deletions tidy3d/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self, data, *args, **kwargs):
# initialize with untraced data
super().__init__(getval(data), *args, **kwargs)
# and put tracers in .attrs

if isbox(data):
self.attrs[AUTOGRAD_KEY] = data

Expand Down
Loading

0 comments on commit d279e19

Please sign in to comment.