Skip to content

Commit

Permalink
reorganized how sim data normalizatio works
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Apr 6, 2022
1 parent bc51fdf commit bcf15c2
Showing 1 changed file with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions tidy3d/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import xarray as xr
import numpy as np
import h5py
import pydantic as pd

from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry, Axis
from .base import Tidy3dBaseModel
Expand Down Expand Up @@ -959,14 +960,24 @@ class SimulationData(AbstractSimulationData):
A string containing the log information from the simulation run.
diverged : bool = False
A boolean flag denoting if the simulation run diverged.
normalized : bool = False
A boolean flag denoting whether the data has been normalized by the spectrum of a source.
"""

monitor_data: Dict[str, Tidy3dData]
log_string: str = None
diverged: bool = False
normalized: bool = False

# set internally by the normalize function
_normalize_index: pd.NonNegativeInt = pd.PrivateAttr(None)

@property
def normalized(self) -> bool:
"""Is this data normalized?"""
return self._normalize_index is not None

@property
def normalize_index(self) -> pd.NonNegativeInt:
"""What is the index of the source that normalized this data. If ``None``, unnormalized."""
return self._normalize_index

@property
def log(self) -> str:
Expand Down Expand Up @@ -1170,7 +1181,7 @@ def normalize(self, normalize_index: int = 0):
normalize_index : int = 0
If specified, normalizes the frequency-domain data by the amplitude spectrum of the
source corresponding to ``simulation.sources[normalize_index]``.
This occurs when the data is loaded into a :class:`SimulationData` object.
This occurs when the data is loaded into a :class:`.SimulationData` object.
Returns
-------
Expand All @@ -1180,7 +1191,8 @@ def normalize(self, normalize_index: int = 0):

if self.normalized:
raise DataError(
"This SimulationData object has already been normalized"
"This SimulationData object has already been normalized "
f"with `normalize_index` of {self._normalize_index} "
"and can't be normalized again."
)

Expand All @@ -1190,7 +1202,6 @@ def normalize(self, normalize_index: int = 0):
except IndexError as e:
raise DataError(f"Could not locate source at normalize_index={normalize_index}.") from e

source_time = source.source_time
sim_data_norm = self.copy(deep=True)
times = self.simulation.tmesh
dt = self.simulation.dt
Expand All @@ -1216,7 +1227,7 @@ def normalize_data(monitor_data):
else:
normalize_data(monitor_data)

sim_data_norm.normalized = True
sim_data_norm._normalize_index = normalize_index # pylint:disable=protected-access
return sim_data_norm

def to_file(self, fname: str) -> None:
Expand All @@ -1239,7 +1250,7 @@ def to_file(self, fname: str) -> None:

# save diverged and normalized flags as attributes
f_handle.attrs["diverged"] = self.diverged
f_handle.attrs["normalized"] = self.normalized
f_handle.attrs["normalize_index"] = self._normalize_index

# make a group for monitor_data
mon_data_grp = f_handle.create_group("monitor_data")
Expand Down Expand Up @@ -1283,9 +1294,7 @@ def from_file(
logging.warning("Simulation run has diverged!")

# get whether this data has been normalized
normalized = f_handle.attrs.get("normalized")
if normalized is None:
normalized = False
normalize_index_file = f_handle.attrs.get("normalize_index")

# loop through monitor dataset and create all MonitorData instances
monitor_data_dict = {}
Expand All @@ -1302,18 +1311,25 @@ def from_file(
monitor_data=monitor_data_dict,
log_string=log_string,
diverged=diverged,
normalized=normalized,
)

# handle normalization
if normalize_index is not None:
if sim_data.normalized:
raise DataError(
"Data from this file is already normalized. "
"Instead, load `.from_file()` with `normalize_index=None."
)
sim_data = sim_data.normalize(normalize_index=normalize_index)
# make sure to tag the SimulationData with the normalize_index stored from file
sim_data._normalize_index = normalize_index_file

# if the data in the file has not been normalized, normalize with supplied index.
if normalize_index_file is None:
return sim_data.normalize(normalize_index=normalize_index)

# if normalize_index supplied and different from one in the file, raise.
if normalize_index is not None and normalize_index != normalize_index_file:
raise DataError(
"Data from this file is already normalized with "
f"normalize_index={normalize_index_file}, can't normalize with supplied "
f"normalize_index={normalize_index} unless they are the same "
"or supplied normalize index is `None`."
)

# otherwise, just return the sim_data as it came from the file.
return sim_data

def __eq__(self, other):
Expand Down

0 comments on commit bcf15c2

Please sign in to comment.