Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CustomSourceTime with times completely outside envelope range #1901

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- `DataArray` interpolation failure due to incorrect ordering of coordinates when interpolating with autograd tracers.
- Error in `CustomSourceTime` when evaluating at a list of times entirely outside of the range of the envelope definition times.

## [2.7.2] - 2024-08-07

Expand Down
12 changes: 12 additions & 0 deletions tests/test_components/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,20 @@ def test_custom_source_time(log_capture):
atol=ATOL,
)

# all times out of range
_ = cst.amp_time([-1])
_ = cst.amp_time(-1)
assert np.allclose(cst.amp_time([2]), np.exp(-1j * 2 * np.pi * 2 * freq0), rtol=0, atol=ATOL)

assert_log_level(log_capture, None)

vals = td.components.data.data_array.TimeDataArray([1, 2], coords=dict(t=[-1, -0.5]))
dataset = td.components.data.dataset.TimeDataset(values=vals)
cst = td.CustomSourceTime(source_time_dataset=dataset, freq0=freq0, fwidth=0.1e12)
source = td.PointDipole(center=(0, 0, 0), source_time=cst, polarization="Ex")
with AssertLogLevel(log_capture, "WARNING", contains_str="defined over a time range"):
sim = sim.updated_copy(sources=[source])

# test normalization warning
with AssertLogLevel(log_capture, "WARNING"):
sim = sim.updated_copy(normalize_index=0)
Expand Down
18 changes: 18 additions & 0 deletions tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3161,6 +3161,24 @@ def _post_init_validators(self) -> None:
self._validate_no_structures_pml()
self._validate_tfsf_nonuniform_grid()
self._validate_nonlinear_specs()
self._validate_custom_source_time()

def _validate_custom_source_time(self):
"""Warn if all simulation times are outside CustomSourceTime definition range."""
run_time = self._run_time
for idx, source in enumerate(self.sources):
if isinstance(source.source_time, CustomSourceTime):
if source.source_time._all_outside_range(run_time=run_time):
data_times = source.source_time.data_times
mint = np.min(data_times)
maxt = np.max(data_times)
log.warning(
f"'CustomSourceTime' at 'sources[{idx}]' is defined over a time range "
f"'({mint}, {maxt})' which does not include any of the 'Simulation' "
f"times '({0, run_time})'. The envelope will be constant extrapolated "
"from the first or last value in the 'CustomSourceTime', which may not "
"be the desired outcome."
)

def _validate_no_structures_pml(self) -> None:
"""Ensure no structures terminate / have bounds inside of PML."""
Expand Down
34 changes: 30 additions & 4 deletions tidy3d/components/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,31 @@ def from_values(
source_time_dataset=source_time_dataset,
)

@property
def data_times(self) -> ArrayFloat1D:
"""Times of envelope definition."""
if self.source_time_dataset is None:
return []
data_times = self.source_time_dataset.values.coords["t"].values.squeeze()
return data_times

def _all_outside_range(self, run_time: float) -> bool:
"""Whether all times are outside range of definition."""

# can't validate if data isn't loaded
if self.source_time_dataset is None:
return False

# make time a numpy array for uniform handling
data_times = self.data_times

# shift time
twidth = 1.0 / (2 * np.pi * self.fwidth)
max_time_shifted = run_time - self.offset * twidth
min_time_shifted = -self.offset * twidth

return (max_time_shifted < min(data_times)) | (min_time_shifted > max(data_times))

def amp_time(self, time: float) -> complex:
"""Complex-valued source amplitude as a function of time.

Expand All @@ -370,8 +395,8 @@ def amp_time(self, time: float) -> complex:
return None

# make time a numpy array for uniform handling
times = np.array([time] if isinstance(time, float) else time)
data_times = self.source_time_dataset.values.coords["t"].values.squeeze()
times = np.array([time] if isinstance(time, (int, float)) else time)
data_times = self.data_times

# shift time
twidth = 1.0 / (2 * np.pi * self.fwidth)
Expand All @@ -384,12 +409,13 @@ def amp_time(self, time: float) -> complex:
envelope = np.zeros(len(time_shifted), dtype=complex)
values = self.source_time_dataset.values
envelope[mask] = values.sel(t=time_shifted[mask], method="nearest").to_numpy()
envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy()
if not all(mask):
envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy()

# modulation, phase, amplitude
omega0 = 2 * np.pi * self.freq0
offset = np.exp(1j * self.phase)
oscillation = np.exp(-1j * omega0 * time)
oscillation = np.exp(-1j * omega0 * times)
amp = self.amplitude

return offset * oscillation * amp * envelope
Expand Down
Loading