Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

normalize sim data upon .from_file() #287

Merged
merged 3 commits into from
Apr 6, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions tidy3d/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
"""Classes for Storing Monitor and Simulation Data."""

from abc import ABC, abstractmethod
from typing import Dict, List, Union
from typing import Dict, List, Union, Optional
import logging

import xarray as xr
import numpy as np
import h5py
import pydantic as pd

from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry, Axis
from .base import Tidy3dBaseModel
Expand Down Expand Up @@ -959,14 +960,24 @@ class SimulationData(AbstractSimulationData):
A string containing the log information from the simulation run.
diverged : bool = False
A boolean flag denoting if the simulation run diverged.
normalized : bool = False
A boolean flag denoting whether the data has been normalized by the spectrum of a source.
"""

monitor_data: Dict[str, Tidy3dData]
log_string: str = None
diverged: bool = False
normalized: bool = False

# set internally by the normalize function
_normalize_index: pd.NonNegativeInt = pd.PrivateAttr(None)

@property
def normalized(self) -> bool:
"""Is this data normalized?"""
return self._normalize_index is not None

@property
def normalize_index(self) -> pd.NonNegativeInt:
"""Index of the source used to normalize this data, or ``None`` if the data is unnormalized."""
return self._normalize_index

@property
def log(self) -> str:
Expand Down Expand Up @@ -1170,7 +1181,7 @@ def normalize(self, normalize_index: int = 0):
normalize_index : int = 0
If specified, normalizes the frequency-domain data by the amplitude spectrum of the
source corresponding to ``simulation.sources[normalize_index]``.
This occurs when the data is loaded into a :class:`SimulationData` object.
This occurs when the data is loaded into a :class:`.SimulationData` object.

Returns
-------
Expand All @@ -1180,7 +1191,8 @@ def normalize(self, normalize_index: int = 0):

if self.normalized:
raise DataError(
"This SimulationData object has already been normalized"
"This SimulationData object has already been normalized "
f"with `normalize_index` of {self._normalize_index} "
"and can't be normalized again."
)

Expand All @@ -1190,7 +1202,6 @@ def normalize(self, normalize_index: int = 0):
except IndexError as e:
raise DataError(f"Could not locate source at normalize_index={normalize_index}.") from e

source_time = source.source_time
sim_data_norm = self.copy(deep=True)
times = self.simulation.tmesh
dt = self.simulation.dt
Expand All @@ -1216,7 +1227,7 @@ def normalize_data(monitor_data):
else:
normalize_data(monitor_data)

sim_data_norm.normalized = True
sim_data_norm._normalize_index = normalize_index # pylint:disable=protected-access
return sim_data_norm

def to_file(self, fname: str) -> None:
Expand All @@ -1237,8 +1248,9 @@ def to_file(self, fname: str) -> None:
if self.log_string:
Tidy3dData.save_string(f_handle, "log_string", self.log_string)

# save diverged flag as an attribute
# save the diverged flag and the normalize_index as attributes
f_handle.attrs["diverged"] = self.diverged
f_handle.attrs["normalize_index"] = self._normalize_index

# make a group for monitor_data
mon_data_grp = f_handle.create_group("monitor_data")
Expand All @@ -1249,7 +1261,9 @@ def to_file(self, fname: str) -> None:
mon_data.add_to_group(mon_grp)

@classmethod
def from_file(cls, fname: str):
def from_file(
cls, fname: str, normalize_index: Optional[int] = 0
): # pylint:disable=arguments-differ
"""Load :class:`SimulationData` from .hdf5 file.

Parameters
Expand Down Expand Up @@ -1279,6 +1293,9 @@ def from_file(cls, fname: str):
if diverged:
logging.warning("Simulation run has diverged!")

# get the normalize_index stored in the file (None if the attribute is absent, i.e. not normalized)
normalize_index_file = f_handle.attrs.get("normalize_index")

# loop through monitor dataset and create all MonitorData instances
monitor_data_dict = {}
for monitor_name, monitor_data in f_handle["monitor_data"].items():
Expand All @@ -1288,13 +1305,31 @@ def from_file(cls, fname: str):
monitor_data_instance = _data_type.load_from_group(monitor_data)
monitor_data_dict[monitor_name] = monitor_data_instance

# create a SimulationData object
sim_data = cls(
simulation=simulation,
monitor_data=monitor_data_dict,
log_string=log_string,
diverged=diverged,
)

# make sure to tag the SimulationData with the normalize_index stored from file
sim_data._normalize_index = normalize_index_file

# if the data in the file has not been normalized, normalize with supplied index.
if normalize_index_file is None:
return sim_data.normalize(normalize_index=normalize_index)

# if normalize_index supplied and different from one in the file, raise.
if normalize_index is not None and normalize_index != normalize_index_file:
raise DataError(
"Data from this file is already normalized with "
f"normalize_index={normalize_index_file}, can't normalize with supplied "
f"normalize_index={normalize_index} unless they are the same "
"or supplied normalize index is `None`."
)

# otherwise, just return the sim_data as it came from the file.
return sim_data

def __eq__(self, other):
Expand Down