diff --git a/poetry.lock b/poetry.lock index 98bbd44..0604d57 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4010,4 +4010,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "0e05d7166a1cd987614bc44402561c5fdf5497ee8ea083c61cff08410a5d68e2" +content-hash = "ffff04f4488a8896218f2abba066c8c6a880a95dd1151f23dacaaa1c6bcbd45c" diff --git a/pyproject.toml b/pyproject.toml index a00d04b..0117976 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ flake8 = "^7.0.0" black = "^24.1.1" pyarrow = "^15.0.0" getdist = "^1.4.7" +h5py = "^3.10.0" [tool.poetry.dev-dependencies] pytest = "^7.3.2" diff --git a/src/scripts/io.py b/src/scripts/io.py index 68c7a28..9cb70f4 100644 --- a/src/scripts/io.py +++ b/src/scripts/io.py @@ -1,5 +1,7 @@ import pickle import h5py +import numpy as np +import torch class ModelLoader: def save_model_pkl(self, path, model_name, posterior): @@ -82,13 +84,15 @@ def save_data_h5(self, :param data_name: Name of the data :param data: Data to be saved """ + data_arrays = {key: np.asarray(value) for key, value in data.items()} + file_name = path + data_name + ".h5" with h5py.File(file_name, "w") as file: - file.create_dataset(data_name, data=data) + # Save each array as a dataset in the HDF5 file + for key, value in data_arrays.items(): + file.create_dataset(key, data=value) - def load_data_h5(self, - data_name, - path='../saveddata/'): + def load_data_h5(self, data_name, path='../saveddata/'): """ Load data from an h5 file. @@ -97,6 +101,8 @@ def load_data_h5(self, :return: Loaded data """ file_name = path + data_name + ".h5" + loaded_data = {} with h5py.File(file_name, "r") as file: - data = file[data_name][...] - return data + for key in file.keys(): + loaded_data[key] = torch.Tensor(file[key][...]) + return loaded_data \ No newline at end of file