Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added facade functions to_zarr and from_zarr #2236

Merged
merged 6 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
- Add InferenceData<->DataTree conversion functions ([2253](https://github.com/arviz-devs/arviz/pull/2253))
- Bayes Factor plot: Use arviz's kde instead of the one from scipy ([2237](https://github.com/arviz-devs/arviz/pull/2237))
- InferenceData objects can now be appended to existing netCDF4 files and to specific groups within them ([2227](https://github.com/arviz-devs/arviz/pull/2227))
- Added facade functions `az.to_zarr` and `az.from_zarr` ([2236](https://github.com/arviz-devs/arviz/pull/2236))

### Maintenance and fixes
- Replace deprecated np.product with np.prod ([2249](https://github.com/arviz-devs/arviz/pull/2249))
- Fix numba deprecation warning ([2246](https://github.com/arviz-devs/arviz/pull/2246))
- Fixes for creating numpy object array ([2233](https://github.com/arviz-devs/arviz/pull/2233) and [2239](https://github.com/arviz-devs/arviz/pull/2239))
- Adapt histograms generated by plot_dist to input dtype ([2247](https://github.com/arviz-devs/arviz/pull/2247))


### Deprecation

### Documentation
Expand Down
3 changes: 3 additions & 0 deletions arviz/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .io_pyjags import from_pyjags
from .io_pyro import from_pyro
from .io_pystan import from_pystan
from .io_zarr import from_zarr, to_zarr
from .utils import extract, extract_dataset

__all__ = [
Expand Down Expand Up @@ -44,6 +45,8 @@
"to_datatree",
"to_json",
"to_netcdf",
"from_zarr",
"to_zarr",
"CoordSpec",
"DimSpec",
]
46 changes: 46 additions & 0 deletions arviz/data/io_zarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Input and output support for zarr data."""

from .converters import convert_to_inference_data
from .inference_data import InferenceData


def from_zarr(store):
return InferenceData.from_zarr(store)


from_zarr.__doc__ = InferenceData.from_zarr.__doc__


def to_zarr(data, store=None, **kwargs):
"""
Convert data to zarr, optionally saving to disk if ``store`` is provided.

The zarr storage is using the same group names as the InferenceData.

Parameters
----------
store : zarr.storage, MutableMapping or str, optional
Zarr storage class or path to desired DirectoryStore.
Default (None) a store is created in a temporary directory.
**kwargs : dict, optional
Passed to :py:func:`convert_to_inference_data`.

Returns
-------
zarr.hierarchy.group
A zarr hierarchy group containing the InferenceData.

Raises
------
TypeError
If no valid store is found.


References
----------
https://zarr.readthedocs.io/

"""
inference_data = convert_to_inference_data(data, **kwargs)
zarr_group = inference_data.to_zarr(store=store)
return zarr_group
39 changes: 39 additions & 0 deletions arviz/tests/base_tests/test_data_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from ... import InferenceData, from_dict
from ... import to_zarr, from_zarr

from ..helpers import ( # pylint: disable=unused-import
chains,
Expand Down Expand Up @@ -103,3 +104,41 @@ def test_io_method(self, data, eight_schools_params, store, fill_attrs):
assert inference_data2.attrs["test"] == 1
else:
assert "test" not in inference_data2.attrs

def test_io_function(self, data, eight_schools_params):
# create InferenceData and check it has been properly created
inference_data = self.get_inference_data( # pylint: disable=W0612
data,
eight_schools_params,
fill_attrs=True,
)
test_dict = {
"posterior": ["eta", "theta", "mu", "tau"],
"posterior_predictive": ["eta", "theta", "mu", "tau"],
"sample_stats": ["eta", "theta", "mu", "tau"],
"prior": ["eta", "theta", "mu", "tau"],
"prior_predictive": ["eta", "theta", "mu", "tau"],
"sample_stats_prior": ["eta", "theta", "mu", "tau"],
"observed_data": ["J", "y", "sigma"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails

assert inference_data.attrs["test"] == 1

# check filename does not exist and use to_zarr method
with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
filepath = os.path.join(tmp_dir, "zarr")

to_zarr(inference_data, store=filepath)
# assert file has been saved correctly
assert os.path.exists(filepath)
assert os.path.getsize(filepath) > 0

inference_data2 = from_zarr(filepath)

# Everything in dict still available in inference_data2 ?
fails = check_multiple_attrs(test_dict, inference_data2)
assert not fails

assert inference_data2.attrs["test"] == 1
2 changes: 2 additions & 0 deletions doc/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ IO / General conversion
to_datatree
to_json
to_netcdf
from_zarr
to_zarr


General functions
Expand Down