Skip to content

Commit

Permalink
explicit flag to keep data as numpy to speed up data IO
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex authored and momchil-flex committed Aug 5, 2022
1 parent 567d864 commit 2ae68d4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
33 changes: 22 additions & 11 deletions tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,17 @@ def to_hdf5(self, fname: str) -> None:
"""

@staticmethod
def unpack_dataset(dataset: h5py.Dataset) -> Any: # pylint:disable=too-many-return-statements
def unpack_dataset( # pylint:disable=too-many-return-statements
dataset: h5py.Dataset, keep_numpy: bool = False
) -> Any:
"""Gets the value contained in a dataset in a form ready to insert into final dict.
Parameters
----------
item : h5py.Dataset
The raw value coming from the dataset, which needs to be decoded.
keep_numpy : bool = False
Whether to load a ``np.ndarray`` as such or convert it to list.
Returns
-------
Expand All @@ -320,10 +324,9 @@ def unpack_dataset(dataset: h5py.Dataset) -> Any: # pylint:disable=too-many-ret
return [val.decode("utf-8") for val in value]
if value.dtype == bool:
return value.astype(bool)
# handle xarray datasets implicitly (retain np.ndarray type)
if len(value.shape) >= 4:
return value
return value.tolist()
if not keep_numpy:
return value.tolist()
return value

# decoding special types
if isinstance(value, np.bool_):
Expand Down Expand Up @@ -352,18 +355,28 @@ def load_from_handle(cls, hdf5_group: h5py.Group, **kwargs) -> Tidy3dBaseModel:
return cls.parse_obj(data_dict, **kwargs)

@classmethod
def _load_group_data(cls, data_dict: dict, hdf5_group: h5py.Group) -> dict:
def _load_group_data(
cls, data_dict: dict, hdf5_group: h5py.Group, keep_numpy: bool = False
) -> dict:
"""Recusively load the data from the group with dataset unpacking as base case."""

if "keep_numpy" in hdf5_group:
keep_numpy = hdf5_group["keep_numpy"]

for key, value in hdf5_group.items():

if key == "keep_numpy":
continue

# recurive case, try to load the group into data_dict[key]
if isinstance(value, h5py.Group):
data_dict[key] = cls._load_group_data(data_dict={}, hdf5_group=value)
data_dict[key] = cls._load_group_data(
data_dict={}, hdf5_group=value, keep_numpy=keep_numpy
)

# base case, unpack the value in the dataset
elif isinstance(value, h5py.Dataset):
data_dict[key] = cls.unpack_dataset(value)
data_dict[key] = cls.unpack_dataset(value, keep_numpy=keep_numpy)

if any("TUPLE_ELEMENT_" in key for key in data_dict.keys()):
return tuple(data_dict.values())
Expand Down Expand Up @@ -398,8 +411,6 @@ def pack_dataset(hdf5_group: h5py.Group, key: str, value: Any) -> None:
return
if isinstance(value, str):
value = value.encode("utf-8")
elif isinstance(value, bool):
value = np.array(value)

# numpy array containing strings (usually direction=['-','+'])
elif isinstance(value, np.ndarray) and (value.dtype == "<U1"):
Expand Down Expand Up @@ -429,7 +440,7 @@ def _save_group_data(self, data_dict: dict, hdf5_group: h5py.Group) -> None:

if isinstance(value, xr.DataArray):
coords = {key: np.array(val) for key, val in value.coords.items()}
value = dict(data=value.data, coords=coords)
value = dict(data=value.data, coords=coords, keep_numpy=True)

# if a tuple of dicts, convert to a dict with special
elif isinstance(value, tuple) and any(isinstance(val, dict) for val in value):
Expand Down
9 changes: 7 additions & 2 deletions tidy3d/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,13 @@ def validate(cls, value):
if isinstance(value, dict):
data = value.get("data")
coords = value.get("coords")
coords = {name: np.array(val) for name, val in coords.items()}
return cls(np.array(data), coords=coords)

# convert to numpy if not already
coords = {k: v if isinstance(v, np.ndarray) else np.array(v) for k, v in coords.items()}
if not isinstance(data, np.ndarray):
data = np.array(data)

return cls(data, coords=coords)

return cls(value)

Expand Down

0 comments on commit 2ae68d4

Please sign in to comment.