Skip to content

Commit

Permalink
yannick comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Jul 24, 2024
1 parent 78b0dc7 commit 83e5984
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 47 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Mode solver plugin now supports 'EMESimulation'.
- `TriangleMesh` class: automatic removal of zero-area faces, and functions `fill_holes` and `fix_winding` to attempt mesh repair.
- Automatic differentiation with `autograd` supports multiple frequencies through single, broadband adjoint simulation.
- Automatic differentiation with `autograd` supports multiple frequencies through single, broadband adjoint simulation when the objective depends on a single port.

### Fixed
Expand Down
79 changes: 46 additions & 33 deletions tidy3d/components/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ...exceptions import DataError, FileError, Tidy3dKeyError
from ...log import log
from ..base import JSON_TAG
from ..base import JSON_TAG, Tidy3dBaseModel
from ..base_sim.data.sim_data import AbstractSimulationData
from ..file_util import replace_values
from ..monitor import Monitor
Expand All @@ -37,7 +37,31 @@
DATA_TYPE_NAME_MAP = {val.__fields__["monitor"].type_.__name__: val for val in MonitorDataTypes}

# residuals below this are considered good fits for broadband adjoint source creation
RESIDUAL_CUTOFF_ADJOINT = 0.5
RESIDUAL_CUTOFF_ADJOINT = 1e-6


class AdjointSourceInfo(Tidy3dBaseModel):
"""Stores information about the adjoint sources to pass to autograd pipeline."""

sources: tuple[Source, ...] = pd.Field(
...,
title="Adjoint Sources",
description="Set of processed sources to include in the adjoint simulation.",
)

post_norm: Union[float, xr.DataArray] = pd.Field(
...,
title="Post Normalization Values",
description="Factor to multiply the adjoint fields by after running "
"given the adjoint source pipeline used.",
)

normalize_sim: bool = pd.Field(
...,
title="Normalize Adjoint Simulation",
description="Whether the adjoint simulation needs to be normalized "
"given the adjoint source pipeline used.",
)


class AbstractYeeGridSimulationData(AbstractSimulationData, ABC):
Expand Down Expand Up @@ -956,7 +980,7 @@ def source_spectrum_fn(freqs):

def make_adjoint_sim(
self, data_vjp_paths: set[tuple], adjoint_monitors: list[Monitor]
) -> tuple[Simulation, float]:
) -> tuple[Simulation, AdjointSourceInfo]:
"""Make the adjoint simulation from the original simulation and the VJP-containing data."""

sim_original = self.simulation
Expand All @@ -967,19 +991,19 @@ def make_adjoint_sim(
for src_list in sources_adj_dict.values():
adj_srcs += list(src_list)

sources_adj, post_norm, norm_source = self.process_adjoint_sources(adj_srcs=adj_srcs)
adjoint_source_info = self.process_adjoint_sources(adj_srcs=adj_srcs)

# grab boundary conditions with flipped Bloch vectors (for adjoint)
bc_adj = sim_original.boundary_spec.flipped_bloch_vecs

# fields to update the 'fwd' simulation with to make it 'adj'
sim_adj_update_dict = dict(
sources=sources_adj,
sources=adjoint_source_info.sources,
boundary_spec=bc_adj,
monitors=adjoint_monitors,
)

if not norm_source:
if not adjoint_source_info.normalize_sim:
sim_adj_update_dict["normalize_index"] = None

# set the ADJ grid spec wavelength to the original wavelength (for same meshing)
Expand All @@ -989,7 +1013,7 @@ def make_adjoint_sim(
grid_spec_adj = grid_spec_original.updated_copy(wavelength=wavelength_original)
sim_adj_update_dict["grid_spec"] = grid_spec_adj

return sim_original.updated_copy(**sim_adj_update_dict), post_norm
return sim_original.updated_copy(**sim_adj_update_dict), adjoint_source_info

def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, Source]:
"""Generate all of the non-zero sources for the adjoint simulation given the VJP data."""
Expand All @@ -1010,29 +1034,14 @@ def make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, Source]:

return sources_adj_all

@staticmethod
def get_amp(src_time: GaussianPulse) -> complex:
"""grab the complex amplitude from a ``SourceTime``."""
mag = src_time.amplitude
phase = np.exp(1j * src_time.phase)
return mag * phase

@staticmethod
def set_amp(src_time: GaussianPulse, amp: complex) -> GaussianPulse:
"""set the complex amplitude of a ``SourceTime``."""
amplitude = abs(amp)
phase = np.angle(amp)
return src_time.updated_copy(amplitude=amplitude, phase=phase)

@property
def fwidth_adj(self) -> float:
# fwidth of forward pass, try as default for adjoint
normalize_index_fwd = self.simulation.normalize_index or 0
return self.simulation.sources[normalize_index_fwd].source_time.fwidth

def process_adjoint_sources(
self, adj_srcs: list[Source]
) -> tuple[list[Source], Union[float, xr.DataArray], bool]:
def process_adjoint_sources(self, adj_srcs: list[Source]) -> AdjointSourceInfo:
# tuple[list[Source], Union[float, xr.DataArray], bool]:
"""Compute list of final sources along with a post run normalization for adj fields."""

# dictionary mapping unique spatial dependence of each Source to list of time-dependencies
Expand All @@ -1049,12 +1058,14 @@ def process_adjoint_sources(
# next, figure out which treatment / normalization to apply
if num_unique_freqs == 1:
log.info("adjoint source creation: one unique frequency, no normalization")
return adj_srcs, 1.0, True
return AdjointSourceInfo(sources=adj_srcs, post_norm=1.0, normalize_sim=True)
# return adj_srcs, 1.0, True

if num_ports == 1 and len(adj_srcs) == num_unique_freqs:
log.info("adjoint source creation: one spatial port detected")
adj_srcs, post_norm = self.process_adjoint_sources_broadband(adj_srcs)
return adj_srcs, post_norm, True
return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=True)
# return adj_srcs, post_norm, True

# if several spatial ports and several frequencies, try to fit
log.info("adjoint source creation: trying multifrequency fit.")
Expand All @@ -1063,7 +1074,8 @@ def process_adjoint_sources(
spatial_to_src_times=spatial_to_src_times,
json_to_sources=json_to_sources,
)
return adj_srcs, post_norm, False
return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=False)
# return adj_srcs, post_norm, False

""" SIMPLE APPROACH """

Expand All @@ -1084,7 +1096,7 @@ def process_adjoint_sources_broadband(

return [src_broadband], post_norm_amps

def _make_broadband_source(self, adj_srcs: list[Source], num_fwidth: float = 0.5) -> Source:
def _make_broadband_source(self, adj_srcs: list[Source]) -> Source:
"""Make a broadband source for a set of adjoint sources."""

source_index = self.simulation.normalize_index or 0
Expand Down Expand Up @@ -1139,15 +1151,16 @@ def process_adjoint_sources_fit(
# compute amplitudes of each adjoint source, and the norm
adj_src_amps = []
for src in new_adj_srcs:
amp = self.get_amp(src.source_time)
amp = src.source_time.amp_complex
adj_src_amps.append(amp)
norm_amps = np.linalg.norm(adj_src_amps)

# normalize all of the adjoint sources by this and return the normalization term used
adj_srcs_norm = []
for src in new_adj_srcs:
amp = self.get_amp(src.source_time)
src_time_norm = self.set_amp(src_time=src.source_time, amp=amp / norm_amps)
src_time = src.source_time
amp = src_time.amp_complex
src_time_norm = src_time.from_amp_complex(amp=amp / norm_amps)
src_nrm = src.updated_copy(source_time=src_time_norm)
adj_srcs_norm.append(src_nrm)

Expand Down Expand Up @@ -1180,7 +1193,7 @@ def get_coupling_matrix(fwidth: float) -> np.ndarray:
]
).T

amps_adj = np.array([self.get_amp(src_time) for src_time in source_times])
amps_adj = np.array([src_time.amp_complex for src_time in source_times])

# compute the corrected set of amps to inject at each freq to take coupling into account
def get_amps_corrected(fwidth: float) -> tuple[np.ndarray, float]:
Expand All @@ -1207,7 +1220,7 @@ def get_amps_corrected(fwidth: float) -> tuple[np.ndarray, float]:

# construct the new adjoint sources with the corrected amplitudes
src_times_corrected = [
self.set_amp(src_time=src_time, amp=amp).updated_copy(fwidth=self.fwidth_adj)
src_time.from_amp_complex(amp=amp, fwidth=self.fwidth_adj)
for src_time, amp in zip(source_times, amps_corrected)
]
srcs_corrected = []
Expand Down
2 changes: 1 addition & 1 deletion tidy3d/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -2587,7 +2587,7 @@ def _derivative_field_cmp(
E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).real
E_der_dim_interp = E_der_dim_interp.sum("f")

vjp_array = np.array(E_der_dim_interp.values).astype(float)
vjp_array = np.array(E_der_dim_interp.values, dtype=float)

vjp_array = vjp_array.reshape(eps_data.shape)

Expand Down
14 changes: 14 additions & 0 deletions tidy3d/components/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,20 @@ def end_time(self) -> float | None:

return self.offset * self.twidth + END_TIME_FACTOR_GAUSSIAN * self.twidth

@property
def amp_complex(self) -> complex:
"""grab the complex amplitude from a ``SourceTime``."""
mag = self.amplitude
phase = np.exp(1j * self.phase)
return mag * phase

@classmethod
def from_amp_complex(cls, amp: complex, **kwargs) -> GaussianPulse:
"""set the complex amplitude of a ``SourceTime``."""
amplitude = abs(amp)
phase = np.angle(amp)
return cls(amplitude=amplitude, phase=phase, **kwargs)


class ContinuousWave(Pulse):
"""Source time dependence that ramps up to continuous oscillation
Expand Down
25 changes: 13 additions & 12 deletions tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tidy3d as td
from tidy3d.components.autograd import AutogradFieldMap, get_static
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
from tidy3d.components.data.sim_data import AdjointSourceInfo

from ..asynchronous import DEFAULT_DATA_DIR
from ..asynchronous import run_async as run_async_webapi
Expand Down Expand Up @@ -486,7 +487,7 @@ def _run_bwd(
def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
"""dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}"""

sim_adj, post_norm = setup_adj(
sim_adj, adjoint_source_info = setup_adj(
data_fields_vjp=data_fields_vjp,
sim_data_orig=sim_data_orig,
sim_data_fwd=sim_data_fwd,
Expand All @@ -502,7 +503,7 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap:
sim_data_orig=sim_data_orig,
sim_data_fwd=sim_data_fwd,
sim_fields_original=sim_fields_original,
post_norm=post_norm,
adjoint_source_info=adjoint_source_info,
)

return vjp
Expand Down Expand Up @@ -535,21 +536,21 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
task_names_adj = {task_name + "_adjoint" for task_name in task_names}

sims_adj = {}
post_norm_dict = {}
adjoint_source_info_dict = {}
for task_name, task_name_adj in zip(task_names, task_names_adj):
data_fields_vjp = data_fields_dict_vjp[task_name]
sim_data_orig = sim_data_orig_dict[task_name]
sim_data_fwd = sim_data_fwd_dict[task_name]
sim_fields_original = sim_fields_original_dict[task_name]

sim_adj, post_norm = setup_adj(
sim_adj, adjoint_source_info = setup_adj(
data_fields_vjp=data_fields_vjp,
sim_data_orig=sim_data_orig,
sim_data_fwd=sim_data_fwd,
sim_fields_original=sim_fields_original,
)
sims_adj[task_name_adj] = sim_adj
post_norm_dict[task_name_adj] = post_norm
adjoint_source_info_dict[task_name_adj] = adjoint_source_info

# TODO: handle case where no adjoint sources?

Expand All @@ -559,7 +560,7 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
sim_fields_vjp_dict = {}
for task_name, task_name_adj in zip(task_names, task_names_adj):
sim_data_adj = batch_data_adj[task_name_adj]
post_norm = post_norm_dict[task_name_adj]
adjoint_source_info = adjoint_source_info_dict[task_name_adj]
sim_data_orig = sim_data_orig_dict[task_name]
sim_data_fwd = sim_data_fwd_dict[task_name]
sim_fields_original = sim_fields_original_dict[task_name]
Expand All @@ -568,7 +569,7 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd
sim_data_orig=sim_data_orig,
sim_data_fwd=sim_data_fwd,
sim_fields_original=sim_fields_original,
post_norm=post_norm,
adjoint_source_info=adjoint_source_info,
)
sim_fields_vjp_dict[task_name] = sim_fields_vjp

Expand All @@ -582,7 +583,7 @@ def setup_adj(
sim_data_orig: td.SimulationData,
sim_data_fwd: td.SimulationData,
sim_fields_original: AutogradFieldMap,
) -> tuple[td.Simulation, float]:
) -> tuple[td.Simulation, AdjointSourceInfo]:
"""Construct an adjoint simulation from a set of data_fields for the VJP."""

td.log.info("Running custom vjp (adjoint) pipeline.")
Expand All @@ -597,21 +598,21 @@ def setup_adj(

# make adjoint simulation from that SimulationData
data_vjp_paths = set(data_fields_vjp.keys())
sim_adj, post_norm = sim_data_vjp.make_adjoint_sim(
sim_adj, adjoint_source_info = sim_data_vjp.make_adjoint_sim(
data_vjp_paths=data_vjp_paths, adjoint_monitors=sim_data_fwd.simulation.monitors
)

td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.")

return sim_adj, post_norm
return sim_adj, adjoint_source_info


def postprocess_adj(
sim_data_adj: td.SimulationData,
sim_data_orig: td.SimulationData,
sim_data_fwd: td.SimulationData,
sim_fields_original: AutogradFieldMap,
post_norm: float,
adjoint_source_info: AdjointSourceInfo,
) -> AutogradFieldMap:
"""Postprocess some data from the adjoint simulation into the VJP for the original sim flds."""

Expand All @@ -634,7 +635,7 @@ def postprocess_adj(

fwd_flds_normed = {}
for key, val in fld_adj.field_components.items():
fwd_flds_normed[key] = val * post_norm
fwd_flds_normed[key] = val * adjoint_source_info.post_norm

fld_adj = fld_adj.updated_copy(**fwd_flds_normed)

Expand Down

0 comments on commit 83e5984

Please sign in to comment.