use hash instead of json in adjoint sources
tylerflex committed Jul 29, 2024
1 parent b67435f commit 93e3b31
Showing 1 changed file with 16 additions and 13 deletions.
tidy3d/components/data/sim_data.py (16 additions, 13 deletions)
@@ -13,6 +13,7 @@
 import pydantic.v1 as pd
 import xarray as xr
 
+from ...constants import C_0
 from ...exceptions import DataError, FileError, Tidy3dKeyError
 from ...log import log
 from ..base import JSON_TAG, Tidy3dBaseModel
@@ -1044,15 +1045,17 @@ def process_adjoint_sources(self, adj_srcs: list[Source]) -> AdjointSourceInfo:
         # tuple[list[Source], Union[float, xr.DataArray], bool]:
         """Compute list of final sources along with a post run normalization for adj fields."""
 
-        # dictionary mapping unique spatial dependence of each Source to list of time-dependencies
-        json_to_sources = defaultdict(None)
-        spatial_to_src_times = defaultdict(list)
+        # dictionary mapping hash of sources with same freq dependence to list of time-dependencies
+        hashes_to_sources = defaultdict(None)
+        hashes_to_src_times = defaultdict(list)
+        tmp_src_time = GaussianPulse(freq0=C_0, fwidth=self.fwidth_adj)
         for src in adj_srcs:
-            src_spatial_json = src.json(exclude={"source_time"})
-            json_to_sources[src_spatial_json] = src
-            spatial_to_src_times[src_spatial_json].append(src.source_time)
+            tmp_src = src.updated_copy(source_time=tmp_src_time)
+            tmp_src_hash = tmp_src._hash_self()
+            hashes_to_sources[tmp_src_hash] = src
+            hashes_to_src_times[tmp_src_hash].append(src.source_time)
 
-        num_ports = len(spatial_to_src_times)
+        num_ports = len(hashes_to_src_times)
         num_unique_freqs = len({src.source_time.freq0 for src in adj_srcs})
 
         # next, figure out which treatment / normalization to apply
@@ -1069,8 +1072,8 @@ def process_adjoint_sources(self, adj_srcs: list[Source]) -> AdjointSourceInfo:
            log.info("Adjoint source creation: trying multifrequency fit.")
            adj_srcs, post_norm = self.process_adjoint_sources_fit(
                adj_srcs=adj_srcs,
-                spatial_to_src_times=spatial_to_src_times,
-                json_to_sources=json_to_sources,
+                hashes_to_src_times=hashes_to_src_times,
+                hashes_to_sources=hashes_to_sources,
            )
            return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=False)
 
@@ -1123,8 +1126,8 @@ def _make_post_norm_amps(adj_srcs: list[Source]) -> xr.DataArray:
     def process_adjoint_sources_fit(
         self,
         adj_srcs: list[Source],
-        spatial_to_src_times: dict[str, GaussianPulse],
-        json_to_sources: dict[str, list[Source]],
+        hashes_to_src_times: dict[int, GaussianPulse],
+        hashes_to_sources: dict[int, list[Source]],
     ) -> tuple[list[Source], float]:
         """Process the adjoint sources using a least squared fit to the derivative data."""
 
@@ -1138,8 +1141,8 @@
 
         # new adjoint sources
         new_adj_srcs = []
-        for src_json, source_times in spatial_to_src_times.items():
-            src = json_to_sources[src_json]
+        for src_hash, source_times in hashes_to_src_times.items():
+            src = hashes_to_sources[src_hash]
             new_sources = self.correct_adjoint_sources(
                 src=src, fwidth=self.fwidth_adj, source_times=source_times
             )
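
Summary of the change: adjoint sources that share the same spatial and frequency definition are now grouped by the hash of a copy whose source_time is swapped for a common placeholder pulse, rather than by a JSON dump that excludes source_time. Below is a minimal, standalone sketch of that grouping pattern for illustration only; it uses a toy dataclass in place of tidy3d's Source and GaussianPulse types, and all names in it (ToySource, group_by_hash, PLACEHOLDER_TIME) are hypothetical, not part of the library.

# Standalone sketch (not tidy3d code): group sources by a hash computed after
# normalizing away the time dependence, mirroring the pattern in the diff above.
from collections import defaultdict
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ToySource:
    center: tuple      # stands in for the spatial / frequency definition
    source_time: str   # stands in for the GaussianPulse time dependence

PLACEHOLDER_TIME = "placeholder"  # plays the role of tmp_src_time in the diff

def group_by_hash(sources):
    hashes_to_sources = {}
    hashes_to_src_times = defaultdict(list)
    for src in sources:
        # replace the time dependence so only the remaining fields affect the hash
        key = hash(replace(src, source_time=PLACEHOLDER_TIME))
        hashes_to_sources[key] = src
        hashes_to_src_times[key].append(src.source_time)
    return hashes_to_sources, hashes_to_src_times

srcs = [
    ToySource(center=(0, 0, 0), source_time="pulse_a"),
    ToySource(center=(0, 0, 0), source_time="pulse_b"),  # same spatial spec -> same group
    ToySource(center=(1, 0, 0), source_time="pulse_a"),  # new spatial spec -> new group
]
sources_by_hash, times_by_hash = group_by_hash(srcs)
assert len(sources_by_hash) == 2  # two unique "ports"

The commit itself does not state the motivation, but a plausible reading is that keying on the model hash avoids serializing each source to JSON on every call while still collapsing sources that differ only in their time dependence.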
