diff --git a/CHANGELOG.md b/CHANGELOG.md index d03f110677..0455105b3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Automatic differentiation with `autograd` supports multiple frequencies through single, broadband adjoint simulation. + ### Fixed - Error when loading a previously run `Batch` or `ComponentModeler` containing custom data. diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index d811d7363c..88999a6662 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -12,9 +12,9 @@ import numpy as np import pytest import tidy3d as td +import xarray as xr from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.web import run_async -from tidy3d.web.api.autograd.autograd import run +from tidy3d.web import run, run_async from ..utils import SIM_FULL, AssertLogLevel, run_emulated @@ -53,6 +53,7 @@ WVL = 1.0 FREQ0 = td.C_0 / WVL +FREQS = [0.9 * FREQ0, FREQ0, FREQ0 * 1.1] # sim sizes LZ = 7 * WVL @@ -154,8 +155,10 @@ def run_async_emulated(simulations, **kwargs): def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: """Make a dictionary of the structures given the parameters.""" + np.random.seed(0) + vector = np.random.random(N_PARAMS) - 0.5 - vector /= np.linalg.norm(vector) + vector = vector / np.linalg.norm(vector) # static components box = td.Box(center=(0, 0, 0), size=(1, 1, 1)) @@ -411,7 +414,8 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = False) -> None: if TEST_POLYSLAB_SPEED: args = [("polyslab", "mode")] -# args = [("geo_group", "mode")] + +# args = [("custom_med", "mode")] def get_functions(structure_key: str, monitor_key: str) -> typing.Callable: @@ -612,33 +616,13 @@ def objective(*params): ag.grad(objective)(params0) -def test_warning_no_adjoint_sources(log_capture, monkeypatch, use_emulated_run): - """Make sure we get the right warning with no adjoint sources, and no error.""" - - monitor_key = "mode" - structure_key = "size_element" - monitor, postprocess = make_monitors()[monitor_key] - - def make_sim(*args): - structure = make_structures(*args)[structure_key] - return SIM_BASE.updated_copy(structures=[structure], monitors=[monitor]) - - def objective(*args): - """Objective function.""" - sim = make_sim(*args) - data = run(sim, task_name="autograd_test", verbose=False) - value = postprocess(data, data[monitor_key]) - return value - - monkeypatch.setattr(td.SimulationData, "make_adjoint_sources", lambda *args, **kwargs: []) - - with AssertLogLevel(log_capture, "WARNING", contains_str="No adjoint sources"): - ag.grad(objective)(params0) - - def test_web_failure_handling(log_capture, monkeypatch, use_emulated_run, use_emulated_run_async): """Test what happens when autograd run pipeline fails.""" + def fail(*args, **kwargs): + """Just raise an exception.""" + raise ValueError("test") + monitor_key = "mode" structure_key = "size_element" monitor, postprocess = make_monitors()[monitor_key] @@ -650,14 +634,10 @@ def make_sim(*args): def objective(*args): """Objective function.""" sim = make_sim(*args) - data = run(sim, task_name="autograd_test", verbose=False) + data = run(sim, task_name=None, verbose=False) value = postprocess(data, data[monitor_key]) return value - def fail(*args, **kwargs): - """Just raise an exception.""" - raise ValueError("test") - """ if autograd run raises exception, raise a warning and continue with regular .""" monkeypatch.setattr(td.web.api.autograd.autograd, "_run", fail) @@ -1008,3 +988,179 @@ def f(x): * no copy : 16 sec * no to_static(): 13 sec """ + +FREQ1 = FREQ0 * 1.1 + +mnt_single = td.ModeMonitor( + size=(2, 2, 0), + center=(0, 0, LZ / 2 - WVL), + mode_spec=td.ModeSpec(num_modes=2), + freqs=[FREQ0], + name="single", +) + +mnt_multi = td.ModeMonitor( + size=(2, 2, 0), + center=(0, 0, LZ / 2 - WVL), + mode_spec=td.ModeSpec(num_modes=2), + freqs=[FREQ0, FREQ1], + name="multi", +) + + +def make_objective(postprocess_fn: typing.Callable, structure_key: str) -> typing.Callable: + def objective(params): + structure_traced = make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + monitors=list(SIM_BASE.monitors) + [mnt_single, mnt_multi], + ) + data = run(sim, task_name="multifreq_test") + return postprocess_fn(data) + + return objective + + +def get_amps(sim_data: td.SimulationData, mnt_name: str) -> xr.DataArray: + return sim_data[mnt_name].amps + + +def power(amps: xr.DataArray) -> float: + """Reduce a selected DataArray into just a float for objective function.""" + return anp.sum(anp.abs(amps.values) ** 2) + + +def postprocess_0_src(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 0 adjoint sources.""" + return 0.0 + + +def compute_grad(postprocess_fn: typing.Callable, structure_key: str) -> typing.Callable: + objective = make_objective(postprocess_fn, structure_key=structure_key) + params = params0 + 1.0 # +1 is to avoid a warning in size_element with value 0 + return ag.grad(objective)(params) + + +def check_1_src_single(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 1 adjoint sources.""" + amps = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + return power(amps) + + +def check_2_src_single(log_capture, structure_key): + def postprocess_2_src_single(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 2 different adjoint sources.""" + amps = get_amps(sim_data, "single").sel(mode_index=0) + return power(amps) + + +def check_1_src_multi(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 1 adjoint sources.""" + amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0) + return power(amps) + + +def check_2_src_multi(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 2 different adjoint sources.""" + amps = get_amps(sim_data, "multi").sel(mode_index=0, f=FREQ1) + return power(amps) + + +def check_2_src_both(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 2 different adjoint sources.""" + amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0) + return power(amps_single) + power(amps_multi) + + +def check_1_multisrc(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should raise ValueError because diff sources, diff freqs.""" + amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ1) + return power(amps_single) + power(amps_multi) + + +def check_2_multisrc(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should raise ValueError because diff sources, diff freqs.""" + amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") + return power(amps_single) + power(amps_multi) + + +def check_1_src_broadband(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 1 broadband adjoint sources with many freqs.""" + amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") + return power(amps) + + +MULT_FREQ_TEST_CASES = dict( + src_1_freq_1=check_1_src_single, + src_2_freq_1=check_2_src_single, + src_1_freq_2=check_1_src_multi, + src_2_freq_1_mon_1=check_1_src_multi, + src_2_freq_1_mon_2=check_2_src_both, + src_2_freq_2_mon_1=check_1_multisrc, + src_2_freq_2_mon_2=check_2_multisrc, + src_1_freq_2_broadband=check_1_src_broadband, +) + +checks = list(MULT_FREQ_TEST_CASES.items()) + + +@pytest.mark.parametrize("label, check_fn", checks) +@pytest.mark.parametrize("structure_key", structure_keys_) +def test_multi_freq_edge_cases(log_capture, use_emulated_run, structure_key, label, check_fn): + # test multi-frequency adjoint handling + check_fn(structure_key=structure_key, log_capture=log_capture) + + +@pytest.mark.parametrize("structure_key", structure_keys_) +def test_multi_frequency_equivalence(use_emulated_run, structure_key): + """Test an objective function through tidy3d autograd.""" + + def objective_indi(params, structure_key) -> float: + power_sum = 0.0 + + for f in mnt_multi.freqs: + structure_traced = make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + monitors=list(SIM_BASE.monitors) + [mnt_multi], + ) + + sim_data = run(sim, task_name="multifreq_test") + amps_i = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=f) + power_i = power(amps_i) + power_sum = power_sum + power_i + + return power_sum + + def objective_multi(params, structure_key) -> float: + structure_traced = make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + monitors=list(SIM_BASE.monitors) + [mnt_multi], + ) + sim_data = run(sim, task_name="multifreq_test") + amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") + return power(amps) + + params0_ = params0 + 1.0 + + J_indi = objective_indi(params0_, structure_key) + J_multi = objective_multi(params0_, structure_key) + + np.testing.assert_allclose(J_indi, J_multi) + + grad_indi = ag.grad(objective_indi)(params0_, structure_key=structure_key) + grad_multi = ag.grad(objective_multi)(params0_, structure_key=structure_key) + + assert not np.any(np.isclose(grad_indi, 0)) + assert not np.any(np.isclose(grad_multi, 0)) diff --git a/tests/utils.py b/tests/utils.py index 141a6882b6..94e40a7e0c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,6 +16,8 @@ """ utilities shared between all tests """ np.random.seed(4) +# function used to generate the data for emulated runs +DATA_GEN_FN = np.random.random FREQS = np.array([1.90, 2.01, 2.2]) * 1e12 SIM_MONITORS = td.Simulation( @@ -880,7 +882,7 @@ def make_data( """make a random DataArray out of supplied coordinates and data_type.""" data_shape = [len(coords[k]) for k in data_array_type._dims] np.random.seed(1) - data = np.random.random(data_shape) + data = DATA_GEN_FN(data_shape) data = (1 + 0.5j) * data if is_complex else data data = gaussian_filter(data, sigma=1.0) # smooth out the data a little so it isnt random @@ -939,7 +941,7 @@ def make_mode_solver_data(monitor: td.ModeSolverMonitor) -> td.ModeSolverData: index_coords["mode_index"] = np.arange(monitor.mode_spec.num_modes) index_data_shape = (len(index_coords["f"]), len(index_coords["mode_index"])) index_data = ModeIndexDataArray( - (1 + 1j) * np.random.random(index_data_shape), coords=index_coords + (1 + 1j) * DATA_GEN_FN(index_data_shape), coords=index_coords ) for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: coords = get_spatial_coords_dict(simulation, monitor, field_name) @@ -977,7 +979,7 @@ def make_diff_data(monitor: td.DiffractionMonitor) -> td.DiffractionData: orders_x = np.linspace(-1, 1, 3) orders_y = np.linspace(-2, 2, 5) coords = dict(orders_x=orders_x, orders_y=orders_y, f=f) - values = np.random.random((len(orders_x), len(orders_y), len(f))) + values = DATA_GEN_FN((len(orders_x), len(orders_y), len(f))) data = td.DiffractionDataArray(values, coords=coords) field_data = {field: data for field in ("Er", "Etheta", "Ephi", "Hr", "Htheta", "Hphi")} return td.DiffractionData(monitor=monitor, sim_size=(1, 1), bloch_vecs=(0, 0), **field_data) diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index f0107b4ddf..cac924500d 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -20,7 +20,7 @@ from ..file_util import replace_values from ..monitor import Monitor from ..simulation import Simulation -from ..source import Source +from ..source import GaussianPulse, Source from ..structure import Structure from ..types import Ax, Axis, ColormapType, FieldVal, PlotScale, annotate_type from ..viz import add_ax_if_none, equal_aspect @@ -953,13 +953,18 @@ def source_spectrum_fn(freqs): def make_adjoint_sim( self, data_vjp_paths: set[tuple], adjoint_monitors: list[Monitor] - ) -> Simulation: + ) -> tuple[Simulation, float]: """Make the adjoint simulation from the original simulation and the VJP-containing data.""" sim_original = self.simulation - # generate the adjoint sources - sources_adj = self.make_adjoint_sources(data_vjp_paths=data_vjp_paths) + # generate the adjoint sources {mnt_name : list[Source]} + sources_adj_dict = self.make_adjoint_sources(data_vjp_paths=data_vjp_paths) + adj_srcs = [] + for src_list in sources_adj_dict.values(): + adj_srcs += list(src_list) + + sources_adj, post_norm = self.process_adjoint_sources(adj_srcs=adj_srcs) # grab boundary conditions with flipped Bloch vectors (for adjoint) bc_adj = sim_original.boundary_spec.flipped_bloch_vecs @@ -969,6 +974,8 @@ def make_adjoint_sim( sources=sources_adj, boundary_spec=bc_adj, monitors=adjoint_monitors, + normalize_index=None, + # TODO: set a longer run time, depending on what the adjoint sources look like ) # set the ADJ grid spec wavelength to the original wavelength (for same meshing) @@ -978,27 +985,131 @@ def make_adjoint_sim( grid_spec_adj = grid_spec_original.updated_copy(wavelength=wavelength_original) sim_adj_update_dict["grid_spec"] = grid_spec_adj - return sim_original.updated_copy(**sim_adj_update_dict) + return sim_original.updated_copy(**sim_adj_update_dict), post_norm - def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> list[Source]: + def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, Source]: """Generate all of the non-zero sources for the adjoint simulation given the VJP data.""" - # TODO: determine if we can do multi-frequency sources - # map of index into 'self.data' to the list of datasets we need adjoint sources for adj_src_map = defaultdict(list) for _, index, dataset_name in data_vjp_paths: adj_src_map[index].append(dataset_name) - # gather a list of adjoint sources for every monitor data in the VJP that needs one - sources_adj_all = [] + # gather a dict of adjoint sources for every monitor data in the VJP that needs one + sources_adj_all = defaultdict(list) for data_index, dataset_names in adj_src_map.items(): mnt_data = self.data[data_index] sources_adj = mnt_data.make_adjoint_sources(dataset_names=dataset_names) - sources_adj_all += sources_adj + sources_adj_all[mnt_data.monitor.name] = sources_adj return sources_adj_all + @staticmethod + def get_amp(src_time: GaussianPulse) -> complex: + """grab the complex amplitude from a ``SourceTime``.""" + mag = src_time.amplitude + phase = np.exp(1j * src_time.phase) + return mag * phase + + @staticmethod + def set_amp(src_time: GaussianPulse, amp: complex) -> GaussianPulse: + """set the complex amplitude of a ``SourceTime``.""" + amplitude = abs(amp) + phase = np.angle(amp) + return src_time.updated_copy(amplitude=amplitude, phase=phase) + + def process_adjoint_sources(self, adj_srcs: list[Source]) -> tuple[list[Source], float]: + """Compute list of final sources along with a post run normalization for adj fields.""" + + # fwidth of forward pass, try as default for adjoint + normalize_index_fwd = self.simulation.normalize_index or 0 + fwidth_adj = self.simulation.sources[normalize_index_fwd].source_time.fwidth + + # dictionary mapping unique spatial dependence of each Source to list of time-dependencies + json_to_sources = defaultdict(None) + spatial_to_src_times = defaultdict(list) + for src in adj_srcs: + src_spatial_json = src.json(exclude={"source_time"}) + json_to_sources[src_spatial_json] = src + spatial_to_src_times[src_spatial_json].append(src.source_time) + + # new adjoint sources + new_adj_srcs = [] + for src_json, source_times in spatial_to_src_times.items(): + src = json_to_sources[src_json] + new_sources = self.correct_adjoint_sources( + src=src, fwidth=fwidth_adj, source_times=source_times + ) + new_adj_srcs += new_sources + + # compute amplitudes of each adjoint source, and the norm + adj_src_amps = [] + for src in new_adj_srcs: + amp = self.get_amp(src.source_time) + adj_src_amps.append(amp) + norm_amps = np.linalg.norm(adj_src_amps) + + # normalize all of the adjoint sources by this and return the normalization term used + adj_srcs_norm = [] + for src in new_adj_srcs: + amp = self.get_amp(src.source_time) + src_time_norm = self.set_amp(src_time=src.source_time, amp=amp / norm_amps) + src_nrm = src.updated_copy(source_time=src_time_norm) + adj_srcs_norm.append(src_nrm) + + return adj_srcs_norm, norm_amps + + def correct_adjoint_sources( + self, src: Source, fwidth: float, source_times: list[GaussianPulse] + ) -> [Source]: + """Corret a set of spectrally overlapping adjoint sources to give correct E_adj.""" + + freqs = [st.freq0 for st in source_times] + times = self.simulation.tmesh + dt = self.simulation.dt + + def get_spectrum(source_time: GaussianPulse, freqs: list[float]) -> complex: + """Get the spectrum of a source time at a given frequency.""" + return source_time.spectrum(times=times, freqs=freqs, dt=dt) + + # compute matrix coupling the spectra of Gaussian pulses centered at each adjoint freq + def get_coupling_matrix(fwidth: float) -> np.ndarray: + """Matrix coupling the spectra of Gaussian pulses centered at each adjoint freq.""" + + return np.array( + [ + get_spectrum( + source_time=GaussianPulse(freq0=source_time.freq0, fwidth=fwidth), + freqs=freqs, + ) + for source_time in source_times + ] + ).T + + # compute the corrected set of amps to inject at each freq to take coupling into account + def get_amps_corrected(fwidth: float) -> tuple[np.ndarray, float]: + J_coupling = get_coupling_matrix(fwidth=fwidth) + amps_adj = np.array([self.get_amp(src_time) for src_time in source_times]) + amps_adj_new, *info = np.linalg.lstsq(J_coupling, amps_adj, rcond=None) + return amps_adj_new, 0.0 + + # get the corrected amplitudes + fwidth_adj = fwidth + amps_corrected, _ = get_amps_corrected(fwidth_adj) + # TODO: if the fit is bad, reduce fwidth_adj and try again + + # construct the new adjoint sources with the corrected amplitudes + src_times_corrected = [ + self.set_amp(src_time=src_time, amp=amp).updated_copy(fwidth=fwidth_adj) + for src_time, amp in zip(source_times, amps_corrected) + ] + srcs_corrected = [] + for src_time in src_times_corrected: + src_new = src.updated_copy(source_time=src_time) + srcs_corrected.append(src_new) + + return srcs_corrected + def get_adjoint_data(self, structure_index: int, data_type: str) -> MonitorDataType: """Grab the field or permittivity data for a given structure index.""" diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index 79258b0a91..4b5fd6bb4e 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -1150,7 +1150,7 @@ def derivative_eps_complex_volume( ) vjp_value += vjp_value_fld - return vjp_value + return vjp_value.sum("f") class AbstractCustomMedium(AbstractMedium, ABC): @@ -1395,7 +1395,7 @@ def _derivative_field_cmp( # TODO: probably this could be more robust. eg if the DataArray has weird edge cases E_der_dim = E_der_map[f"E{dim}"] - E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum) + E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).sum("f") vjp_array = np.array(E_der_dim_interp.values).astype(complex) vjp_array = vjp_array.reshape(eps_data.shape) @@ -2551,8 +2551,11 @@ def _derivative_field_cmp( eps_data: PermittivityDataset, dim: str, ) -> np.ndarray: - coords_interp = {key: val for key, val in eps_data.coords.items() if len(val) > 1} - dims_sum = {dim for dim in eps_data.coords.keys() if dim not in coords_interp} + """Compute derivative with respect to the ``dim`` components within the custom medium.""" + + coords_interp = {key: eps_data.coords[key] for key in "xyz"} + coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1} + dims_sum = [dim for dim in "xyz" if dim not in coords_interp] # compute sizes along each of the interpolation dimensions sizes_list = [] @@ -2581,8 +2584,11 @@ def _derivative_field_cmp( # TODO: probably this could be more robust. eg if the DataArray has weird edge cases E_der_dim = E_der_map[f"E{dim}"] - E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum) - vjp_array = np.array(E_der_dim_interp.values).astype(complex) + E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).real + E_der_dim_interp = E_der_dim_interp.sum("f") + + vjp_array = np.array(E_der_dim_interp.values).astype(float) + vjp_array = vjp_array.reshape(eps_data.shape) # multiply by volume elements (if possible, being defensive here..) diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index 7a35e8177e..c2b1f69406 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -3254,14 +3254,6 @@ def freqs_adjoint(self) -> list[float]: if isinstance(mnt, FreqMonitor): freqs.update(mnt.freqs) freqs = sorted(freqs) - - if len(freqs) > 1: - raise ValueError( - "Only the same, single frequency is supported in all monitors " - "when using autograd differentiation. " - f"Found {len(freqs)} distinct frequencies in the monitors." - ) - return freqs """ Accounting """ diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index 40a8bd079e..164b29e388 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -2,6 +2,7 @@ import traceback import typing +from collections import defaultdict import numpy as np from autograd.builtins import dict as dict_ag @@ -485,28 +486,13 @@ def _run_bwd( def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: """dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}""" - sim_adj = setup_adj( + sim_adj, post_norm = setup_adj( data_fields_vjp=data_fields_vjp, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, ) - # no adjoint sources, no gradient for you :( - if not len(sim_adj.sources): - td.log.warning( - "No adjoint sources generated. " - "There is likely zero output in the data, or you have no traceable monitors. " - "As a result, the 'SimulationData' returned has no contribution to the gradient. " - "Skipping the adjoint simulation. " - "If this is unexpected, please double check the post-processing function to ensure " - "there is a path from the 'SimulationData' to the objective function return value." - ) - - # TODO: add a test for this - # construct a VJP of all zeros for all tracers in the original simulation - return {path: 0 * value for path, value in sim_fields_original.items()} - # run adjoint simulation task_name_adj = str(task_name) + "_adjoint" sim_data_adj = _run_tidy3d(sim_adj, task_name=task_name_adj, **run_kwargs) @@ -516,6 +502,7 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, + post_norm=post_norm, ) return vjp @@ -548,19 +535,21 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd task_names_adj = {task_name + "_adjoint" for task_name in task_names} sims_adj = {} + post_norm_dict = {} for task_name, task_name_adj in zip(task_names, task_names_adj): data_fields_vjp = data_fields_dict_vjp[task_name] sim_data_orig = sim_data_orig_dict[task_name] sim_data_fwd = sim_data_fwd_dict[task_name] sim_fields_original = sim_fields_original_dict[task_name] - sim_adj = setup_adj( + sim_adj, post_norm = setup_adj( data_fields_vjp=data_fields_vjp, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, ) sims_adj[task_name_adj] = sim_adj + post_norm_dict[task_name_adj] = post_norm # TODO: handle case where no adjoint sources? @@ -570,15 +559,16 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd sim_fields_vjp_dict = {} for task_name, task_name_adj in zip(task_names, task_names_adj): sim_data_adj = batch_data_adj[task_name_adj] + post_norm = post_norm_dict[task_name_adj] sim_data_orig = sim_data_orig_dict[task_name] sim_data_fwd = sim_data_fwd_dict[task_name] sim_fields_original = sim_fields_original_dict[task_name] - sim_fields_vjp = postprocess_adj( sim_data_adj=sim_data_adj, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, + post_norm=post_norm, ) sim_fields_vjp_dict[task_name] = sim_fields_vjp @@ -592,7 +582,7 @@ def setup_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_original: AutogradFieldMap, -) -> td.Simulation: +) -> tuple[td.Simulation, float]: """Construct an adjoint simulation from a set of data_fields for the VJP.""" td.log.info("Running custom vjp (adjoint) pipeline.") @@ -607,13 +597,13 @@ def setup_adj( # make adjoint simulation from that SimulationData data_vjp_paths = set(data_fields_vjp.keys()) - sim_adj = sim_data_vjp.make_adjoint_sim( + sim_adj, post_norm = sim_data_vjp.make_adjoint_sim( data_vjp_paths=data_vjp_paths, adjoint_monitors=sim_data_fwd.simulation.monitors ) td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.") - return sim_adj + return sim_adj, post_norm def postprocess_adj( @@ -621,17 +611,15 @@ def postprocess_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_original: AutogradFieldMap, + post_norm: float, ) -> AutogradFieldMap: """Postprocess some data from the adjoint simulation into the VJP for the original sim flds.""" # map of index into 'structures' to the list of paths we need vjps for - sim_vjp_map = {} + sim_vjp_map = defaultdict(list) for _, structure_index, *structure_path in sim_fields_original.keys(): structure_path = tuple(structure_path) - if structure_index in sim_vjp_map: - sim_vjp_map[structure_index].append(structure_path) - else: - sim_vjp_map[structure_index] = [structure_path] + sim_vjp_map[structure_index].append(structure_path) # store the derivative values given the forward and adjoint data sim_fields_vjp = {} @@ -642,6 +630,10 @@ def postprocess_adj( fld_adj = sim_data_adj.get_adjoint_data(structure_index, data_type="fld") eps_adj = sim_data_adj.get_adjoint_data(structure_index, data_type="eps") + # post normalize the adjoint fields if a single, broadband source + fwd_flds_normed = {key: val * post_norm for key, val in fld_adj.field_components.items()} + fld_adj = fld_adj.updated_copy(**fwd_flds_normed) + # maps of the E_fwd * E_adj and D_fwd * D_adj, each as as td.FieldData & 'Ex', 'Ey', 'Ez' der_maps = get_derivative_maps( fld_fwd=fld_fwd, eps_fwd=eps_fwd, fld_adj=fld_adj, eps_adj=eps_adj @@ -655,8 +647,7 @@ def postprocess_adj( # todo: handle multi-frequency, move to a property? frequencies = {src.source_time.freq0 for src in sim_data_adj.simulation.sources} frequencies = list(frequencies) - assert len(frequencies) == 1, "Multiple adjoint freqs found" - freq_adj = frequencies[0] + freq_adj = frequencies[0] or None eps_in = np.mean(structure.medium.eps_model(freq_adj)) eps_out = np.mean(sim_data_orig.simulation.medium.eps_model(freq_adj))