From aa0ef00400e5acd92caa006ea2a89e55082b0e9d Mon Sep 17 00:00:00 2001 From: Tyler Hughes Date: Mon, 17 Jun 2024 04:55:28 +0200 Subject: [PATCH] broadband adjoint support for autograd --- CHANGELOG.md | 2 + docs/notebooks | 2 +- tests/test_components/test_autograd.py | 262 +++++++++++++++++++++---- tests/utils.py | 8 +- tidy3d/components/data/monitor_data.py | 40 ++-- tidy3d/components/data/sim_data.py | 242 +++++++++++++++++++++-- tidy3d/components/medium.py | 18 +- tidy3d/components/simulation.py | 8 - tidy3d/web/api/autograd/autograd.py | 52 +++-- 9 files changed, 522 insertions(+), 112 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77815cc313..6ab383473b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Mode solver plugin now supports 'EMESimulation'. - `TriangleMesh` class: automatic removal of zero-area faces, and functions `fill_holes` and `fix_winding` to attempt mesh repair. +- Automatic differentiation with `autograd` supports multiple frequencies through a single, broadband adjoint simulation. +- Automatic differentiation with `autograd` supports multiple frequencies through a single, broadband adjoint simulation when the objective depends on a single port. ### Fixed - Error when loading a previously run `Batch` or `ComponentModeler` containing custom data. 
diff --git a/docs/notebooks b/docs/notebooks index 21a53743b6..d917addfb1 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 21a53743b6f8a3b9c6a5d388c5e8a6352a42adc2 +Subproject commit d917addfb1cdabe12dba14c96b0c6c056b663fa6 diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index d811d7363c..e72595a7e8 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -12,9 +12,9 @@ import numpy as np import pytest import tidy3d as td +import xarray as xr from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.web import run_async -from tidy3d.web.api.autograd.autograd import run +from tidy3d.web import run, run_async from ..utils import SIM_FULL, AssertLogLevel, run_emulated @@ -53,6 +53,7 @@ WVL = 1.0 FREQ0 = td.C_0 / WVL +FREQS = [FREQ0] # sim sizes LZ = 7 * WVL @@ -154,8 +155,10 @@ def run_async_emulated(simulations, **kwargs): def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: """Make a dictionary of the structures given the parameters.""" + np.random.seed(0) + vector = np.random.random(N_PARAMS) - 0.5 - vector /= np.linalg.norm(vector) + vector = vector / np.linalg.norm(vector) # static components box = td.Box(center=(0, 0, 0), size=(1, 1, 1)) @@ -411,7 +414,8 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = False) -> None: if TEST_POLYSLAB_SPEED: args = [("polyslab", "mode")] -# args = [("geo_group", "mode")] + +# args = [("custom_med", "mode")] def get_functions(structure_key: str, monitor_key: str) -> typing.Callable: @@ -612,33 +616,13 @@ def objective(*params): ag.grad(objective)(params0) -def test_warning_no_adjoint_sources(log_capture, monkeypatch, use_emulated_run): - """Make sure we get the right warning with no adjoint sources, and no error.""" - - monitor_key = "mode" - structure_key = "size_element" - monitor, postprocess = make_monitors()[monitor_key] - - def make_sim(*args): - structure = 
make_structures(*args)[structure_key] - return SIM_BASE.updated_copy(structures=[structure], monitors=[monitor]) - - def objective(*args): - """Objective function.""" - sim = make_sim(*args) - data = run(sim, task_name="autograd_test", verbose=False) - value = postprocess(data, data[monitor_key]) - return value - - monkeypatch.setattr(td.SimulationData, "make_adjoint_sources", lambda *args, **kwargs: []) - - with AssertLogLevel(log_capture, "WARNING", contains_str="No adjoint sources"): - ag.grad(objective)(params0) - - def test_web_failure_handling(log_capture, monkeypatch, use_emulated_run, use_emulated_run_async): """Test what happens when autograd run pipeline fails.""" + def fail(*args, **kwargs): + """Just raise an exception.""" + raise ValueError("test") + monitor_key = "mode" structure_key = "size_element" monitor, postprocess = make_monitors()[monitor_key] @@ -650,14 +634,10 @@ def make_sim(*args): def objective(*args): """Objective function.""" sim = make_sim(*args) - data = run(sim, task_name="autograd_test", verbose=False) + data = run(sim, task_name=None, verbose=False) value = postprocess(data, data[monitor_key]) return value - def fail(*args, **kwargs): - """Just raise an exception.""" - raise ValueError("test") - """ if autograd run raises exception, raise a warning and continue with regular .""" monkeypatch.setattr(td.web.api.autograd.autograd, "_run", fail) @@ -1008,3 +988,219 @@ def f(x): * no copy : 16 sec * no to_static(): 13 sec """ + +FREQ1 = FREQ0 * 1.6 + +mnt_single = td.ModeMonitor( + size=(2, 2, 0), + center=(0, 0, LZ / 2 - WVL), + mode_spec=td.ModeSpec(num_modes=2), + freqs=[FREQ0], + name="single", +) + +mnt_multi = td.ModeMonitor( + size=(2, 2, 0), + center=(0, 0, LZ / 2 - WVL), + mode_spec=td.ModeSpec(num_modes=2), + freqs=[FREQ0, FREQ1], + name="multi", +) + + +def make_objective(postprocess_fn: typing.Callable, structure_key: str) -> typing.Callable: + def objective(params): + structure_traced = 
make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + monitors=list(SIM_BASE.monitors) + [mnt_single, mnt_multi], + ) + data = run(sim, task_name="multifreq_test") + return postprocess_fn(data) + + return objective + + +def get_amps(sim_data: td.SimulationData, mnt_name: str) -> xr.DataArray: + return sim_data[mnt_name].amps + + +def power(amps: xr.DataArray) -> float: + """Reduce a selected DataArray into just a float for objective function.""" + return anp.sum(anp.abs(amps.values) ** 2) + + +def postprocess_0_src(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 0 adjoint sources.""" + return 0.0 + + +def compute_grad(postprocess_fn: typing.Callable, structure_key: str) -> typing.Callable: + objective = make_objective(postprocess_fn, structure_key=structure_key) + params = params0 + 1.0 # +1 is to avoid a warning in size_element with value 0 + return ag.grad(objective)(params) + + +def check_1_src_single(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 1 adjoint sources.""" + amps = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + return power(amps) + + return postprocess + + +def check_2_src_single(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 2 different adjoint sources.""" + amps = get_amps(sim_data, "single").sel(mode_index=0) + return power(amps) + + return postprocess + + +def check_1_src_multi(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 1 adjoint sources.""" + amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0) + return power(amps) + + return postprocess + + +def check_2_src_multi(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess 
function that should return 2 different adjoint sources.""" + amps = get_amps(sim_data, "multi").sel(mode_index=0, f=FREQ1) + return power(amps) + + return postprocess + + +def check_2_src_both(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 2 different adjoint sources.""" + amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ0) + return power(amps_single) + power(amps_multi) + + return postprocess + + +def check_1_multisrc(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should raise ValueError because diff sources, diff freqs.""" + amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=FREQ1) + return power(amps_single) + power(amps_multi) + + return postprocess + + +def check_2_multisrc(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should raise ValueError because diff sources, diff freqs.""" + amps_single = get_amps(sim_data, "single").sel(mode_index=0, direction="+") + amps_multi = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") + return power(amps_single) + power(amps_multi) + + return postprocess + + +def check_1_src_broadband(log_capture, structure_key): + def postprocess(sim_data: td.SimulationData) -> float: + """Postprocess function that should return 1 broadband adjoint sources with many freqs.""" + amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") + return power(amps) + + return postprocess + + +MULT_FREQ_TEST_CASES = dict( + src_1_freq_1=check_1_src_single, + src_2_freq_1=check_2_src_single, + src_1_freq_2=check_1_src_multi, + src_2_freq_1_mon_1=check_1_src_multi, + src_2_freq_1_mon_2=check_2_src_both, 
+ src_2_freq_2_mon_1=check_1_multisrc, + src_2_freq_2_mon_2=check_2_multisrc, + src_1_freq_2_broadband=check_1_src_broadband, +) + +checks = list(MULT_FREQ_TEST_CASES.items()) + + +@pytest.mark.parametrize("label, check_fn", checks) +@pytest.mark.parametrize("structure_key", ("custom_med",)) +def test_multi_freq_edge_cases( + log_capture, use_emulated_run, structure_key, label, check_fn, monkeypatch +): + # test multi-frequency adjoint handling + + import tidy3d.components.data.sim_data as sd + + monkeypatch.setattr(sd, "RESIDUAL_CUTOFF_ADJOINT", 1) + reload(td) + + postprocess_fn = check_fn(structure_key=structure_key, log_capture=log_capture) + + def objective(params): + structure_traced = make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + monitors=list(SIM_BASE.monitors) + [mnt_single, mnt_multi], + ) + data = run(sim, task_name="multifreq_test") + return postprocess_fn(data) + + if label == "src_2_freq_2_mon_2": + with pytest.raises(NotImplementedError): + g = ag.grad(objective)(params0) + else: + g = ag.grad(objective)(params0) + print(g) + + +@pytest.mark.parametrize("structure_key", structure_keys_) +def test_multi_frequency_equivalence(use_emulated_run, structure_key): + """Test an objective function through tidy3d autograd.""" + + def objective_indi(params, structure_key) -> float: + power_sum = 0.0 + + for f in mnt_multi.freqs: + structure_traced = make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + monitors=list(SIM_BASE.monitors) + [mnt_multi], + ) + + sim_data = run(sim, task_name="multifreq_test") + amps_i = get_amps(sim_data, "multi").sel(mode_index=0, direction="+", f=f) + power_i = power(amps_i) + power_sum = power_sum + power_i + + return power_sum + + def objective_multi(params, structure_key) -> float: + structure_traced = make_structures(params)[structure_key] + sim = SIM_BASE.updated_copy( + structures=[structure_traced], + 
monitors=list(SIM_BASE.monitors) + [mnt_multi], + ) + sim_data = run(sim, task_name="multifreq_test") + amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") + return power(amps) + + params0_ = params0 + 1.0 + + J_indi = objective_indi(params0_, structure_key) + J_multi = objective_multi(params0_, structure_key) + + np.testing.assert_allclose(J_indi, J_multi) + + grad_indi = ag.grad(objective_indi)(params0_, structure_key=structure_key) + grad_multi = ag.grad(objective_multi)(params0_, structure_key=structure_key) + + assert not np.any(np.isclose(grad_indi, 0)) + assert not np.any(np.isclose(grad_multi, 0)) diff --git a/tests/utils.py b/tests/utils.py index 141a6882b6..94e40a7e0c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,6 +16,8 @@ """ utilities shared between all tests """ np.random.seed(4) +# function used to generate the data for emulated runs +DATA_GEN_FN = np.random.random FREQS = np.array([1.90, 2.01, 2.2]) * 1e12 SIM_MONITORS = td.Simulation( @@ -880,7 +882,7 @@ def make_data( """make a random DataArray out of supplied coordinates and data_type.""" data_shape = [len(coords[k]) for k in data_array_type._dims] np.random.seed(1) - data = np.random.random(data_shape) + data = DATA_GEN_FN(data_shape) data = (1 + 0.5j) * data if is_complex else data data = gaussian_filter(data, sigma=1.0) # smooth out the data a little so it isnt random @@ -939,7 +941,7 @@ def make_mode_solver_data(monitor: td.ModeSolverMonitor) -> td.ModeSolverData: index_coords["mode_index"] = np.arange(monitor.mode_spec.num_modes) index_data_shape = (len(index_coords["f"]), len(index_coords["mode_index"])) index_data = ModeIndexDataArray( - (1 + 1j) * np.random.random(index_data_shape), coords=index_coords + (1 + 1j) * DATA_GEN_FN(index_data_shape), coords=index_coords ) for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: coords = get_spatial_coords_dict(simulation, monitor, field_name) @@ -977,7 +979,7 @@ def make_diff_data(monitor: td.DiffractionMonitor) -> 
td.DiffractionData: orders_x = np.linspace(-1, 1, 3) orders_y = np.linspace(-2, 2, 5) coords = dict(orders_x=orders_x, orders_y=orders_y, f=f) - values = np.random.random((len(orders_x), len(orders_y), len(f))) + values = DATA_GEN_FN((len(orders_x), len(orders_y), len(f))) data = td.DiffractionDataArray(values, coords=coords) field_data = {field: data for field in ("Er", "Etheta", "Ephi", "Hr", "Htheta", "Hphi")} return td.DiffractionData(monitor=monitor, sim_size=(1, 1), bloch_vecs=(0, 0), **field_data) diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index 3474d17a35..5b55ffccdd 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -126,7 +126,7 @@ def _updated(self, update: Dict) -> MonitorData: data_dict.update(update) return type(self).parse_obj(data_dict) - def make_adjoint_sources(self, dataset_names: list[str]) -> list[Source]: + def make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[Source]: """Generate adjoint sources for this ``MonitorData`` instance.""" # TODO: if there's data in the MonitorData, but no adjoint source, then @@ -1018,16 +1018,16 @@ def to_source( ) def make_adjoint_sources( - self, dataset_names: list[str] + self, dataset_names: list[str], fwidth: float ) -> List[Union[CustomCurrentSource, PointDipole]]: """Converts a :class:`.FieldData` to a list of adjoint current or point sources.""" if np.allclose(self.monitor.size, 0): - return self.to_adjoint_point_sources() + return self.to_adjoint_point_sources(fwidth=fwidth) - return self.to_adjoint_field_sources() + return self.to_adjoint_field_sources(fwidth=fwidth) - def to_adjoint_point_sources(self) -> List[PointDipole]: + def to_adjoint_point_sources(self, fwidth: float) -> List[PointDipole]: """Create adjoint point dipole source if this field data contains one item.""" sources = [] @@ -1050,7 +1050,7 @@ def to_adjoint_point_sources(self) -> List[PointDipole]: 
polarization=polarization, source_time=GaussianPulse( freq0=freq0, - fwidth=freq0 / 10, # TODO: how to set this properly? + fwidth=fwidth, amplitude=abs(adj_amp), phase=adj_phase, ), @@ -1061,7 +1061,7 @@ def to_adjoint_point_sources(self) -> List[PointDipole]: return sources - def to_adjoint_field_sources(self) -> List[CustomCurrentSource]: + def to_adjoint_field_sources(self, fwidth: float) -> List[CustomCurrentSource]: """Create adjoint custom field sources if this field data has some dimensionality.""" sources = [] @@ -1111,7 +1111,7 @@ def shift_value(coords) -> float: size=source_geo.size, source_time=GaussianPulse( freq0=freq0, - fwidth=freq0 / 10, # TODO: how to set this properly? + fwidth=fwidth, ), current_dataset=dataset, interpolate=True, @@ -1752,14 +1752,14 @@ def to_dataframe(self) -> DataFrame: return dataset.drop_vars(drop).to_dataframe() - def make_adjoint_sources(self, dataset_names: list[str]) -> list[ModeSource]: + def make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[ModeSource]: """Get all adjoint sources for the ``ModeMonitorData``.""" adjoint_sources = [] for name in dataset_names: if name == "amps": - adjoint_sources += self.make_adjoint_sources_amps() + adjoint_sources += self.make_adjoint_sources_amps(fwidth=fwidth) else: log.warning( f"Can't create adjoint source for 'ModeData.{type(self)}.{name}'. 
" @@ -1770,7 +1770,7 @@ def make_adjoint_sources(self, dataset_names: list[str]) -> list[ModeSource]: return adjoint_sources - def make_adjoint_sources_amps(self) -> list[ModeSource]: + def make_adjoint_sources_amps(self, fwidth: float) -> list[ModeSource]: """Generate adjoint sources for ``ModeMonitorData.amps``.""" coords = self.amps.coords @@ -1786,12 +1786,12 @@ def make_adjoint_sources_amps(self) -> list[ModeSource]: if self.get_amplitude(amp_single) == 0.0: continue - adjoint_source = self.adjoint_source_amp(amp=amp_single) + adjoint_source = self.adjoint_source_amp(amp=amp_single, fwidth=fwidth) adjoint_sources.append(adjoint_source) return adjoint_sources - def adjoint_source_amp(self, amp: DataArray) -> ModeSource: + def adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource: """Generate an adjoint ``ModeSource`` for a single amplitude.""" monitor = self.monitor @@ -1814,7 +1814,7 @@ def adjoint_source_amp(self, amp: DataArray) -> ModeSource: amplitude=abs(src_amp), phase=np.angle(src_amp), freq0=freq0, - fwidth=freq0 / 10, # TODO: how to set this properly? 
+ fwidth=fwidth, ), mode_spec=monitor.mode_spec, size=monitor.size, @@ -2879,13 +2879,13 @@ def _make_dataset(self, fields: Tuple[np.ndarray, ...], keys: Tuple[str, ...]) - """ Autograd code """ - def make_adjoint_sources(self, dataset_names: list[str]) -> list[PlaneWave]: + def make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[PlaneWave]: """Get all adjoint sources for the ``DiffractionMonitor.amps``.""" # NOTE: everything just goes through `.amps`, any post-processing is encoded in E-fields - return self.make_adjoint_sources_amps() + return self.make_adjoint_sources_amps(fwidth=fwidth) - def make_adjoint_sources_amps(self) -> list[PlaneWave]: + def make_adjoint_sources_amps(self, fwidth: float) -> list[PlaneWave]: """Make adjoint sources for outputs that depend on DiffractionData.`amps`.""" amps = self.amps @@ -2912,13 +2912,13 @@ def make_adjoint_sources_amps(self) -> list[PlaneWave]: continue # compute a plane wave for this amplitude (if propagating / not None) - adjoint_source = self.adjoint_source_amp(amp=amp_single) + adjoint_source = self.adjoint_source_amp(amp=amp_single, fwidth=fwidth) if adjoint_source is not None: adjoint_sources.append(adjoint_source) return adjoint_sources - def adjoint_source_amp(self, amp: DataArray) -> PlaneWave: + def adjoint_source_amp(self, amp: DataArray, fwidth: float) -> PlaneWave: """Generate an adjoint ``PlaneWave`` for a single amplitude.""" monitor = self.monitor @@ -2962,7 +2962,7 @@ def adjoint_source_amp(self, amp: DataArray) -> PlaneWave: amplitude=abs(src_amp), phase=np.angle(src_amp), freq0=freq0, - fwidth=freq0 / 10, # TODO: how to set this properly? 
+ fwidth=fwidth, ), direction=self.flip_direction(monitor.normal_dir), angle_theta=angle_theta, diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index f0107b4ddf..0038206217 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -20,7 +20,7 @@ from ..file_util import replace_values from ..monitor import Monitor from ..simulation import Simulation -from ..source import Source +from ..source import GaussianPulse, Source from ..structure import Structure from ..types import Ax, Axis, ColormapType, FieldVal, PlotScale, annotate_type from ..viz import add_ax_if_none, equal_aspect @@ -36,6 +36,9 @@ # maps monitor type (string) to the class of the corresponding data DATA_TYPE_NAME_MAP = {val.__fields__["monitor"].type_.__name__: val for val in MonitorDataTypes} +# residuals below this are considered good fits for broadband adjoint source creation +RESIDUAL_CUTOFF_ADJOINT = 0.5 + class AbstractYeeGridSimulationData(AbstractSimulationData, ABC): """Data from an :class:`.AbstractYeeGridSimulation` involving @@ -953,13 +956,18 @@ def source_spectrum_fn(freqs): def make_adjoint_sim( self, data_vjp_paths: set[tuple], adjoint_monitors: list[Monitor] - ) -> Simulation: + ) -> tuple[Simulation, float]: """Make the adjoint simulation from the original simulation and the VJP-containing data.""" sim_original = self.simulation - # generate the adjoint sources - sources_adj = self.make_adjoint_sources(data_vjp_paths=data_vjp_paths) + # generate the adjoint sources {mnt_name : list[Source]} + sources_adj_dict = self.make_adjoint_sources(data_vjp_paths=data_vjp_paths) + adj_srcs = [] + for src_list in sources_adj_dict.values(): + adj_srcs += list(src_list) + + sources_adj, post_norm, norm_source = self.process_adjoint_sources(adj_srcs=adj_srcs) # grab boundary conditions with flipped Bloch vectors (for adjoint) bc_adj = sim_original.boundary_spec.flipped_bloch_vecs @@ -971,6 +979,9 @@ def make_adjoint_sim( 
monitors=adjoint_monitors, ) + if not norm_source: + sim_adj_update_dict["normalize_index"] = None + # set the ADJ grid spec wavelength to the original wavelength (for same meshing) grid_spec_original = sim_original.grid_spec if sim_original.sources and grid_spec_original.wavelength is None: @@ -978,27 +989,234 @@ def make_adjoint_sim( grid_spec_adj = grid_spec_original.updated_copy(wavelength=wavelength_original) sim_adj_update_dict["grid_spec"] = grid_spec_adj - return sim_original.updated_copy(**sim_adj_update_dict) + return sim_original.updated_copy(**sim_adj_update_dict), post_norm - def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> list[Source]: + def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, Source]: """Generate all of the non-zero sources for the adjoint simulation given the VJP data.""" - # TODO: determine if we can do multi-frequency sources - # map of index into 'self.data' to the list of datasets we need adjoint sources for adj_src_map = defaultdict(list) for _, index, dataset_name in data_vjp_paths: adj_src_map[index].append(dataset_name) - # gather a list of adjoint sources for every monitor data in the VJP that needs one - sources_adj_all = [] + # gather a dict of adjoint sources for every monitor data in the VJP that needs one + sources_adj_all = defaultdict(list) for data_index, dataset_names in adj_src_map.items(): mnt_data = self.data[data_index] - sources_adj = mnt_data.make_adjoint_sources(dataset_names=dataset_names) - sources_adj_all += sources_adj + sources_adj = mnt_data.make_adjoint_sources( + dataset_names=dataset_names, fwidth=self.fwidth_adj + ) + sources_adj_all[mnt_data.monitor.name] = sources_adj return sources_adj_all + @staticmethod + def get_amp(src_time: GaussianPulse) -> complex: + """grab the complex amplitude from a ``SourceTime``.""" + mag = src_time.amplitude + phase = np.exp(1j * src_time.phase) + return mag * phase + + @staticmethod + def set_amp(src_time: GaussianPulse, amp: complex) 
-> GaussianPulse: + """set the complex amplitude of a ``SourceTime``.""" + amplitude = abs(amp) + phase = np.angle(amp) + return src_time.updated_copy(amplitude=amplitude, phase=phase) + + @property + def fwidth_adj(self) -> float: + # fwidth of forward pass, try as default for adjoint + normalize_index_fwd = self.simulation.normalize_index or 0 + return self.simulation.sources[normalize_index_fwd].source_time.fwidth + + def process_adjoint_sources( + self, adj_srcs: list[Source] + ) -> tuple[list[Source], Union[float, xr.DataArray], bool]: + """Compute list of final sources along with a post run normalization for adj fields.""" + + # dictionary mapping unique spatial dependence of each Source to list of time-dependencies + json_to_sources = defaultdict(None) + spatial_to_src_times = defaultdict(list) + for src in adj_srcs: + src_spatial_json = src.json(exclude={"source_time"}) + json_to_sources[src_spatial_json] = src + spatial_to_src_times[src_spatial_json].append(src.source_time) + + num_ports = len(spatial_to_src_times) + num_unique_freqs = len({src.source_time.freq0 for src in adj_srcs}) + + # next, figure out which treatment / normalization to apply + if num_unique_freqs == 1: + log.info("adjoint source creation: one unique frequency, no normalization") + return adj_srcs, 1.0, True + + if num_ports == 1 and len(adj_srcs) == num_unique_freqs: + log.info("adjoint source creation: one spatial port detected") + adj_srcs, post_norm = self.process_adjoint_sources_broadband(adj_srcs) + return adj_srcs, post_norm, True + + # if several spatial ports and several frequencies, try to fit + log.info("adjoint source creation: trying multifrequency fit.") + adj_srcs, post_norm = self.process_adjoint_sources_fit( + adj_srcs=adj_srcs, + spatial_to_src_times=spatial_to_src_times, + json_to_sources=json_to_sources, + ) + return adj_srcs, post_norm, False + + """ SIMPLE APPROACH """ + + def process_adjoint_sources_broadband( + self, adj_srcs: list[Source] + ) -> 
tuple[list[Source], xr.DataArray]: + """Process adjoint sources for the case of several sources at the same freq.""" + + src_broadband = self._make_broadband_source(adj_srcs=adj_srcs) + post_norm_amps = self._make_post_norm_amps(adj_srcs=adj_srcs) + + log.info( + "Several adjoint sources, from one monitor. " + "Only difference between them is the source time. " + "Constructing broadband adjoint source and performing post-run normalization " + f"of fields with {len(post_norm_amps)} frequencies." + ) + + return [src_broadband], post_norm_amps + + def _make_broadband_source(self, adj_srcs: list[Source], num_fwidth: float = 0.5) -> Source: + """Make a broadband source for a set of adjoint sources.""" + + source_index = self.simulation.normalize_index or 0 + src_time_base = self.simulation.sources[source_index].source_time.copy() + src_broadband = adj_srcs[0].updated_copy(source_time=src_time_base) + + return src_broadband + + @staticmethod + def _make_post_norm_amps(adj_srcs: list[Source]) -> xr.DataArray: + """Make a ``DataArray`` containing the complex amplitudes to multiply with adjoint field.""" + + freqs = [] + amps_complex = [] + for src in adj_srcs: + src_time = src.source_time + freqs.append(src_time.freq0) + amp_complex = src_time.amplitude * np.exp(1j * src_time.phase) + amps_complex.append(amp_complex) + + coords = dict(f=freqs) + amps_complex = np.array(amps_complex) + return xr.DataArray(amps_complex, coords=coords) + + """ FITTING APPROACH """ + + def process_adjoint_sources_fit( + self, + adj_srcs: list[Source], + spatial_to_src_times: dict[str, GaussianPulse], + json_to_sources: dict[str, list[Source]], + ) -> tuple[list[Source], float]: + """Process the adjoint sources using a least squared fit to the derivative data.""" + + raise NotImplementedError( + "Can't perform multi-frequency autograd with several adjoint sources yet. 
" + "In the meantime, please construct a single 'Simulation' per output data " + "(can be multi-frequency) and run in parallel using 'web.run_async'. For example, " + "if your problem has 'P' outuput ports, e.g. waveguides, please make a 'Simulation' " + "corresponding to the objective function contribution at each port." + ) + + # new adjoint sources + new_adj_srcs = [] + for src_json, source_times in spatial_to_src_times.items(): + src = json_to_sources[src_json] + new_sources = self.correct_adjoint_sources( + src=src, fwidth=self.fwidth_adj, source_times=source_times + ) + new_adj_srcs += new_sources + + # compute amplitudes of each adjoint source, and the norm + adj_src_amps = [] + for src in new_adj_srcs: + amp = self.get_amp(src.source_time) + adj_src_amps.append(amp) + norm_amps = np.linalg.norm(adj_src_amps) + + # normalize all of the adjoint sources by this and return the normalization term used + adj_srcs_norm = [] + for src in new_adj_srcs: + amp = self.get_amp(src.source_time) + src_time_norm = self.set_amp(src_time=src.source_time, amp=amp / norm_amps) + src_nrm = src.updated_copy(source_time=src_time_norm) + adj_srcs_norm.append(src_nrm) + + return adj_srcs_norm, norm_amps + + def correct_adjoint_sources( + self, src: Source, fwidth: float, source_times: list[GaussianPulse] + ) -> [Source]: + """Corret a set of spectrally overlapping adjoint sources to give correct E_adj.""" + + freqs = [st.freq0 for st in source_times] + times = self.simulation.tmesh + dt = self.simulation.dt + + def get_spectrum(source_time: GaussianPulse, freqs: list[float]) -> complex: + """Get the spectrum of a source time at a given frequency.""" + return source_time.spectrum(times=times, freqs=freqs, dt=dt) + + # compute matrix coupling the spectra of Gaussian pulses centered at each adjoint freq + def get_coupling_matrix(fwidth: float) -> np.ndarray: + """Matrix coupling the spectra of Gaussian pulses centered at each adjoint freq.""" + + return np.array( + [ + get_spectrum( + 
source_time=GaussianPulse(freq0=source_time.freq0, fwidth=fwidth), + freqs=freqs, + ) + for source_time in source_times + ] + ).T + + amps_adj = np.array([self.get_amp(src_time) for src_time in source_times]) + + # compute the corrected set of amps to inject at each freq to take coupling into account + def get_amps_corrected(fwidth: float) -> tuple[np.ndarray, float]: + """New set of new adjoint source amps that generate the desired response at each f.""" + J_coupling = get_coupling_matrix(fwidth=fwidth) + + amps_adj_new, *info = np.linalg.lstsq(J_coupling, amps_adj, rcond=None) + # amps_adj_new = np.linalg.solve(J_coupling, amps_adj) + residual = J_coupling @ amps_adj_new - amps_adj + residual_norm = np.linalg.norm(residual) / np.linalg.norm(amps_adj) + return amps_adj_new, residual_norm + + # get the corrected amplitudes + amps_corrected, res_norm = get_amps_corrected(self.fwidth_adj) + + if res_norm > RESIDUAL_CUTOFF_ADJOINT: + raise ValueError( + f"Residual of {res_norm:.5e} found when trying to fit adjoint source spectrum. " + f"This is above our accuracy cutoff of {RESIDUAL_CUTOFF_ADJOINT:.5e} and therefore " + "we are not able to process this adjoint simulation in a broadband way. " + "To fix, split your simulation into a set of simulations, one for each port, and " + "run parallel, broadband simulations using 'web.run_async'. 
" + ) + + # construct the new adjoint sources with the corrected amplitudes + src_times_corrected = [ + self.set_amp(src_time=src_time, amp=amp).updated_copy(fwidth=self.fwidth_adj) + for src_time, amp in zip(source_times, amps_corrected) + ] + srcs_corrected = [] + for src_time in src_times_corrected: + src_new = src.updated_copy(source_time=src_time) + srcs_corrected.append(src_new) + + return srcs_corrected + def get_adjoint_data(self, structure_index: int, data_type: str) -> MonitorDataType: """Grab the field or permittivity data for a given structure index.""" diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index 79258b0a91..4b5fd6bb4e 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -1150,7 +1150,7 @@ def derivative_eps_complex_volume( ) vjp_value += vjp_value_fld - return vjp_value + return vjp_value.sum("f") class AbstractCustomMedium(AbstractMedium, ABC): @@ -1395,7 +1395,7 @@ def _derivative_field_cmp( # TODO: probably this could be more robust. 
eg if the DataArray has weird edge cases E_der_dim = E_der_map[f"E{dim}"] - E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum) + E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).sum("f") vjp_array = np.array(E_der_dim_interp.values).astype(complex) vjp_array = vjp_array.reshape(eps_data.shape) @@ -2551,8 +2551,11 @@ def _derivative_field_cmp( eps_data: PermittivityDataset, dim: str, ) -> np.ndarray: - coords_interp = {key: val for key, val in eps_data.coords.items() if len(val) > 1} - dims_sum = {dim for dim in eps_data.coords.keys() if dim not in coords_interp} + """Compute derivative with respect to the ``dim`` components within the custom medium.""" + + coords_interp = {key: eps_data.coords[key] for key in "xyz"} + coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1} + dims_sum = [dim for dim in "xyz" if dim not in coords_interp] # compute sizes along each of the interpolation dimensions sizes_list = [] @@ -2581,8 +2584,11 @@ def _derivative_field_cmp( # TODO: probably this could be more robust. eg if the DataArray has weird edge cases E_der_dim = E_der_map[f"E{dim}"] - E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum) - vjp_array = np.array(E_der_dim_interp.values).astype(complex) + E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).real + E_der_dim_interp = E_der_dim_interp.sum("f") + + vjp_array = np.array(E_der_dim_interp.values).astype(float) + vjp_array = vjp_array.reshape(eps_data.shape) # multiply by volume elements (if possible, being defensive here..) 
diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index f93d08e40c..e1b732a2f0 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -3339,14 +3339,6 @@ def freqs_adjoint(self) -> list[float]: if isinstance(mnt, FreqMonitor): freqs.update(mnt.freqs) freqs = sorted(freqs) - - if len(freqs) > 1: - raise ValueError( - "Only the same, single frequency is supported in all monitors " - "when using autograd differentiation. " - f"Found {len(freqs)} distinct frequencies in the monitors." - ) - return freqs """ Accounting """ diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index ac7c281e8d..35f805797d 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -2,6 +2,7 @@ import traceback import typing +from collections import defaultdict import numpy as np from autograd.builtins import dict as dict_ag @@ -485,28 +486,13 @@ def _run_bwd( def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: """dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}""" - sim_adj = setup_adj( + sim_adj, post_norm = setup_adj( data_fields_vjp=data_fields_vjp, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, ) - # no adjoint sources, no gradient for you :( - if not len(sim_adj.sources): - td.log.warning( - "No adjoint sources generated. " - "There is likely zero output in the data, or you have no traceable monitors. " - "As a result, the 'SimulationData' returned has no contribution to the gradient. " - "Skipping the adjoint simulation. " - "If this is unexpected, please double check the post-processing function to ensure " - "there is a path from the 'SimulationData' to the objective function return value." 
- ) - - # TODO: add a test for this - # construct a VJP of all zeros for all tracers in the original simulation - return {path: 0 * value for path, value in sim_fields_original.items()} - # run adjoint simulation task_name_adj = str(task_name) + "_adjoint" sim_data_adj = _run_tidy3d(sim_adj, task_name=task_name_adj, **run_kwargs) @@ -516,6 +502,7 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, + post_norm=post_norm, ) return vjp @@ -548,19 +535,21 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd task_names_adj = {task_name + "_adjoint" for task_name in task_names} sims_adj = {} + post_norm_dict = {} for task_name, task_name_adj in zip(task_names, task_names_adj): data_fields_vjp = data_fields_dict_vjp[task_name] sim_data_orig = sim_data_orig_dict[task_name] sim_data_fwd = sim_data_fwd_dict[task_name] sim_fields_original = sim_fields_original_dict[task_name] - sim_adj = setup_adj( + sim_adj, post_norm = setup_adj( data_fields_vjp=data_fields_vjp, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, ) sims_adj[task_name_adj] = sim_adj + post_norm_dict[task_name_adj] = post_norm # TODO: handle case where no adjoint sources? 
@@ -570,15 +559,16 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd sim_fields_vjp_dict = {} for task_name, task_name_adj in zip(task_names, task_names_adj): sim_data_adj = batch_data_adj[task_name_adj] + post_norm = post_norm_dict[task_name_adj] sim_data_orig = sim_data_orig_dict[task_name] sim_data_fwd = sim_data_fwd_dict[task_name] sim_fields_original = sim_fields_original_dict[task_name] - sim_fields_vjp = postprocess_adj( sim_data_adj=sim_data_adj, sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_original=sim_fields_original, + post_norm=post_norm, ) sim_fields_vjp_dict[task_name] = sim_fields_vjp @@ -592,7 +582,7 @@ def setup_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_original: AutogradFieldMap, -) -> td.Simulation: +) -> tuple[td.Simulation, float]: """Construct an adjoint simulation from a set of data_fields for the VJP.""" td.log.info("Running custom vjp (adjoint) pipeline.") @@ -607,13 +597,13 @@ def setup_adj( # make adjoint simulation from that SimulationData data_vjp_paths = set(data_fields_vjp.keys()) - sim_adj = sim_data_vjp.make_adjoint_sim( + sim_adj, post_norm = sim_data_vjp.make_adjoint_sim( data_vjp_paths=data_vjp_paths, adjoint_monitors=sim_data_fwd.simulation.monitors ) td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.") - return sim_adj + return sim_adj, post_norm def postprocess_adj( @@ -621,17 +611,15 @@ def postprocess_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_original: AutogradFieldMap, + post_norm: float, ) -> AutogradFieldMap: """Postprocess some data from the adjoint simulation into the VJP for the original sim flds.""" # map of index into 'structures' to the list of paths we need vjps for - sim_vjp_map = {} + sim_vjp_map = defaultdict(list) for _, structure_index, *structure_path in sim_fields_original.keys(): structure_path = tuple(structure_path) - if structure_index in 
sim_vjp_map: - sim_vjp_map[structure_index].append(structure_path) - else: - sim_vjp_map[structure_index] = [structure_path] + sim_vjp_map[structure_index].append(structure_path) # store the derivative values given the forward and adjoint data sim_fields_vjp = {} @@ -642,6 +630,14 @@ def postprocess_adj( fld_adj = sim_data_adj.get_adjoint_data(structure_index, data_type="fld") eps_adj = sim_data_adj.get_adjoint_data(structure_index, data_type="eps") + # post normalize the adjoint fields if a single, broadband source + + fwd_flds_normed = {} + for key, val in fld_adj.field_components.items(): + fwd_flds_normed[key] = val * post_norm + + fld_adj = fld_adj.updated_copy(**fwd_flds_normed) + # maps of the E_fwd * E_adj and D_fwd * D_adj, each as as td.FieldData & 'Ex', 'Ey', 'Ez' der_maps = get_derivative_maps( fld_fwd=fld_fwd, eps_fwd=eps_fwd, fld_adj=fld_adj, eps_adj=eps_adj @@ -655,9 +651,7 @@ def postprocess_adj( # todo: handle multi-frequency, move to a property? frequencies = {src.source_time.freq0 for src in sim_data_adj.simulation.sources} frequencies = list(frequencies) - if len(frequencies) != 1: - raise RuntimeError("Multiple adjoint frequencies found.") - freq_adj = frequencies[0] + freq_adj = frequencies[0] or None eps_in = np.mean(structure.medium.eps_model(freq_adj)) eps_out = np.mean(sim_data_orig.simulation.medium.eps_model(freq_adj))