Skip to content

Commit

Permalink
adding h5
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Feb 6, 2024
1 parent 0119dbc commit c47ccc8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ flake8 = "^7.0.0"
black = "^24.1.1"
pyarrow = "^15.0.0"
getdist = "^1.4.7"
h5py = "^3.10.0"

[tool.poetry.dev-dependencies]
pytest = "^7.3.2"
Expand Down
18 changes: 12 additions & 6 deletions src/scripts/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pickle
import h5py
import numpy as np
import torch

class ModelLoader:
def save_model_pkl(self, path, model_name, posterior):
Expand Down Expand Up @@ -82,13 +84,15 @@ def save_data_h5(self,
:param data_name: Name of the data
:param data: Data to be saved
"""
data_arrays = {key: np.asarray(value) for key, value in data.items()}

file_name = path + data_name + ".h5"
with h5py.File(file_name, "w") as file:
file.create_dataset(data_name, data=data)
# Save each array as a dataset in the HDF5 file
for key, value in data_arrays.items():
file.create_dataset(key, data=value)

def load_data_h5(self,
data_name,
path='../saveddata/'):
def load_data_h5(self, data_name, path='../saveddata/'):
"""
Load data from an h5 file.
Expand All @@ -97,6 +101,8 @@ def load_data_h5(self,
:return: Loaded data
"""
file_name = path + data_name + ".h5"
loaded_data = {}
with h5py.File(file_name, "r") as file:
data = file[data_name][...]
return data
for key in file.keys():
loaded_data[key] = torch.Tensor(file[key][...])
return loaded_data

0 comments on commit c47ccc8

Please sign in to comment.