Skip to content

Commit

Permalink
add datatree converters (#2253)
Browse files Browse the repository at this point in the history
* add datatree converters

* fix to_datatree

* pylint

* fix test

* pylint
  • Loading branch information
OriolAbril authored Jul 6, 2023
1 parent 4ead577 commit c6da92e
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## v0.x.x (TBD)

### New features

- Add InferenceData<->DataTree conversion functions ([2253](https://github.com/arviz-devs/arviz/pull/2253))
- Bayes Factor plot: Use arviz's kde instead of the one from scipy ([2237](https://github.com/arviz-devs/arviz/pull/2237))
- InferenceData objects can now be appended to existing netCDF4 files and to specific groups within them ([2227](https://github.com/arviz-devs/arviz/pull/2227))

Expand Down
6 changes: 5 additions & 1 deletion arviz/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from .io_beanmachine import from_beanmachine
from .io_cmdstan import from_cmdstan
from .io_cmdstanpy import from_cmdstanpy
from .io_datatree import from_datatree, to_datatree
from .io_dict import from_dict
from .io_emcee import from_emcee
from .io_json import from_json
from .io_json import from_json, to_json
from .io_netcdf import from_netcdf, to_netcdf
from .io_numpyro import from_numpyro
from .io_pyjags import from_pyjags
Expand All @@ -34,11 +35,14 @@
"from_emcee",
"from_cmdstan",
"from_cmdstanpy",
"from_datatree",
"from_dict",
"from_json",
"from_pyro",
"from_numpyro",
"from_netcdf",
"to_datatree",
"to_json",
"to_netcdf",
"CoordSpec",
"DimSpec",
Expand Down
1 change: 0 additions & 1 deletion arviz/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def list_datasets():
"""Get a string representation of all available datasets with descriptions."""
lines = []
for name, resource in itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items()):

if isinstance(resource, LocalFileMetadata):
location = f"local: {resource.filename}"
elif isinstance(resource, RemoteFileMetadata):
Expand Down
20 changes: 20 additions & 0 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,26 @@ def to_netcdf(
empty_netcdf_file.close()
return filename

def to_datatree(self):
    """Convert InferenceData object to a :class:`~datatree.DataTree`.

    The ``(group, dataset)`` pairs yielded by ``self.items()`` are handed to
    ``DataTree.from_dict`` as a plain dictionary keyed by group name.

    Returns
    -------
    DataTree

    Raises
    ------
    ModuleNotFoundError
        If the ``datatree`` module cannot be found.
    """
    # Imported lazily so that a missing ``datatree`` only surfaces when this
    # converter is actually called, and with an actionable message.
    try:
        from datatree import DataTree
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "datatree must be installed in order to use InferenceData.to_datatree"
        ) from err
    # ``dict`` consumes the (group, dataset) pairs directly; no comprehension needed.
    return DataTree.from_dict(dict(self.items()))

@staticmethod
def from_datatree(datatree):
    """Create an InferenceData object from a :class:`~datatree.DataTree`.

    Every direct child node of ``datatree`` is converted with its
    ``to_dataset`` method and passed to ``InferenceData`` as a keyword
    argument named after the child.

    Parameters
    ----------
    datatree : DataTree
        Tree whose direct children hold the groups to load.

    Returns
    -------
    InferenceData
    """
    # Iterate over ``children`` instead of the tree itself: the mapping
    # interface of a DataTree node also yields data variables stored on that
    # node, and those must not be turned into InferenceData groups.
    return InferenceData(
        **{group: sub_dt.to_dataset() for group, sub_dt in datatree.children.items()}
    )

def to_dict(self, groups=None, filter_groups=None):
"""Convert InferenceData to a dictionary following xarray naming conventions.
Expand Down
22 changes: 22 additions & 0 deletions arviz/data/io_datatree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Conversion between InferenceData and DataTree."""
from .inference_data import InferenceData


def to_datatree(data):
    """Build a :class:`~datatree.DataTree` out of an InferenceData object.

    Function-style entry point: the conversion itself is delegated to the
    ``to_datatree`` method of ``data``.

    Parameters
    ----------
    data : InferenceData
        Object to convert.

    Returns
    -------
    DataTree
        Whatever ``data.to_datatree()`` returns.
    """
    tree = data.to_datatree()
    return tree


def from_datatree(datatree):
    """Build an InferenceData object out of a :class:`~datatree.DataTree`.

    Function-style entry point: the conversion itself is delegated to
    ``InferenceData.from_datatree``.

    Parameters
    ----------
    datatree : DataTree
        Tree to convert.

    Returns
    -------
    InferenceData
        Whatever ``InferenceData.from_datatree(datatree)`` returns.
    """
    idata = InferenceData.from_datatree(datatree)
    return idata
17 changes: 17 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
# pylint: disable=too-many-lines

import importlib
import os
from collections import namedtuple
from copy import deepcopy
Expand All @@ -21,6 +22,7 @@
concat,
convert_to_dataset,
convert_to_inference_data,
from_datatree,
from_dict,
from_json,
from_netcdf,
Expand All @@ -40,6 +42,7 @@
draws,
eight_schools_params,
models,
running_on_ci,
)


Expand Down Expand Up @@ -1383,6 +1386,20 @@ def test_json_converters(self, models):
assert not os.path.exists(filepath)


@pytest.mark.skipif(
    not (importlib.util.find_spec("datatree") or running_on_ci()),
    reason="test requires xarray-datatree library",
)
class TestDataTree:
    """Round-trip checks for the InferenceData <-> DataTree converters."""

    def test_datatree(self):
        """Converting to a DataTree and back must give back identical groups."""
        original = load_arviz_data("centered_eight")
        tree = original.to_datatree()
        restored = from_datatree(tree)
        # Each group must survive the round trip unchanged.
        for name, dataset in original.items():
            assert_identical(dataset, restored[name])
        # Each group must also be present as a child node of the tree.
        missing = [name for name in original.groups() if name not in tree.children]
        assert not missing


class TestConversions:
def test_id_conversion_idempotent(self):
stored = load_arviz_data("centered_eight")
Expand Down
3 changes: 3 additions & 0 deletions doc/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ IO / General conversion
convert_to_inference_data
convert_to_dataset
dict_to_dataset
from_datatree
from_dict
from_json
from_netcdf
to_datatree
to_json
to_netcdf


Expand Down
1 change: 1 addition & 0 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ contourpy
ujson
dask[distributed]
zarr>=2.5.0
xarray-datatree

0 comments on commit c6da92e

Please sign in to comment.