Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_dataframe to InferenceData #1395

Merged
merged 15 commits into from
Sep 28, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## v0.x.x Unreleased
### New features
* Added `to_dataframe` method to InferenceData ([1395](https://github.com/arviz-devs/arviz/pull/1395))
* Added `__getitem__` magic to InferenceData ([1395](https://github.com/arviz-devs/arviz/pull/1395))

### Maintenance and fixes

Expand Down
2 changes: 1 addition & 1 deletion arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _make_json_serializable(data: dict) -> dict:
for key, value in data.items():
try:
json.dumps(value)
except TypeError:
except (TypeError, OverflowError):
pass
else:
ret[key] = value
Expand Down
145 changes: 142 additions & 3 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ def __iter__(self):
for group in self._groups_all:
yield group

def __getitem__(self, key):
"""Get item by key."""
if key not in self._groups_all:
raise KeyError(key)
return getattr(self, key)
Comment on lines +193 to +197
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️


def groups(self):
    """List every group (including warmup groups) stored on this object."""
    all_groups = self._groups_all
    return all_groups
Expand Down Expand Up @@ -278,7 +284,13 @@ def to_dict(self, groups=None, filter_groups=None):
Parameters
----------
groups : list, optional
Write only these groups to netcdf file.
Groups where the transformation is to be applied. Can either be group names
or metagroup names.
filter_groups: {None, "like", "regex"}, optional, default=None
If `None` (default), interpret groups as the real group or metagroup names.
If "like", interpret groups as substrings of the real group or metagroup names.
If "regex", interpret groups as regular expressions on the real group or
metagroup names. A la `pandas.filter`.

Returns
-------
Expand Down Expand Up @@ -328,13 +340,21 @@ def to_dict(self, groups=None, filter_groups=None):
ret["attrs"] = attrs
return ret

def to_json(self, filename, groups=None, filter_groups=None, **kwargs):
    """Write InferenceData to a json file.

    Parameters
    ----------
    filename : str
        Location to write to
    groups : list, optional
        Groups where the transformation is to be applied. Can either be group names
        or metagroup names.
    filter_groups : {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret groups as the real group or metagroup names.
        If "like", interpret groups as substrings of the real group or metagroup names.
        If "regex", interpret groups as regular expressions on the real group or
        metagroup names. A la `pandas.filter`.
    kwargs : dict
        kwargs passed to json.dump()

    Returns
    -------
    str
        Location of json file
    """
    # Serialize only the selected groups; _make_json_serializable drops any
    # values that json.dumps cannot encode (TypeError/OverflowError).
    idata_dict = _make_json_serializable(
        self.to_dict(groups=groups, filter_groups=filter_groups)
    )

    with open(filename, "w") as file:
        json.dump(idata_dict, file, **kwargs)

    return filename
def to_dataframe(
    self,
    groups=None,
    filter_groups=None,
    include_coords=True,
    include_index=True,
    index_origin=None,
):
    """Convert InferenceData to a pandas DataFrame following xarray naming conventions.

    This returns a "wide"-format dataframe, where each item in an n-dimensional
    array is unpacked into its own column. To access "tidy"-format data, use the
    xarray functionality found on each dataset instead.

    In case of multiple groups, the group name is prepended to the variable name.

    Data groups ("observed_data", "constant_data", "predictions_constant_data")
    are skipped implicitly.

    Raises TypeError if no valid groups are found.

    Parameters
    ----------
    groups : str or list of str, optional
        Groups where the transformation is to be applied. Can either be group names
        or metagroup names.
    filter_groups : {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret groups as the real group or metagroup names.
        If "like", interpret groups as substrings of the real group or metagroup names.
        If "regex", interpret groups as regular expressions on the real group or
        metagroup names. A la `pandas.filter`.
    include_coords : bool
        Add coordinate values to column name (tuple).
    include_index : bool
        Add index information for multidimensional arrays.
    index_origin : {0, 1}, optional
        Starting index for multidimensional objects. 0- or 1-based.
        Defaults to rcParams["data.index_origin"].

    Returns
    -------
    pandas.DataFrame
        A pandas DataFrame containing all selected groups of InferenceData object.
    """
    # pylint: disable=too-many-nested-blocks
    if not include_coords and not include_index:
        raise TypeError("Both include_coords and include_index can not be False.")
    if index_origin is None:
        index_origin = rcParams["data.index_origin"]
    if index_origin not in [0, 1]:
        raise TypeError("index_origin must be 0 or 1, saw {}".format(index_origin))

    # Data groups (anything with "data" in the name) carry no chain/draw
    # dimensions, so they cannot be joined into the sample-indexed frame.
    group_names = list(
        filter(lambda x: "data" not in x, self._group_names(groups, filter_groups))
    )

    if not group_names:
        raise TypeError("No valid groups found: {}".format(groups))

    dfs = {}
    for group in group_names:
        dataset = self[group]
        df = None
        # Map each non-sample coordinate value back to its positional index,
        # offset by index_origin, e.g. {"school": {"choate": 1, ...}} for origin 1.
        coords_to_idx = {
            name: dict(map(reversed, enumerate(dataset.coords[name].values, index_origin)))
            for name in list(filter(lambda x: x not in ("chain", "draw"), dataset.coords))
        }
        for data_array in dataset.values():
            dataframe = data_array.to_dataframe()
            if list(filter(lambda x: x not in ("chain", "draw"), data_array.dims)):
                # Multidimensional variable: pivot the extra dims into columns
                # so rows stay indexed by (chain, draw) only.
                levels = [
                    idx
                    for idx, dim in enumerate(data_array.dims)
                    if dim not in ("chain", "draw")
                ]
                dataframe = dataframe.unstack(level=levels)
                tuple_columns = []
                for name, *coords in dataframe.columns:
                    if include_index:
                        # Translate coordinate values into "var[i,j,...]" labels.
                        idxs = []
                        for coordname, coorditem in zip(dataframe.columns.names[1:], coords):
                            idxs.append(coords_to_idx[coordname][coorditem])
                        if include_coords:
                            tuple_columns.append(
                                ("{}[{}]".format(name, ",".join(map(str, idxs))), *coords)
                            )
                        else:
                            tuple_columns.append(
                                "{}[{}]".format(name, ",".join(map(str, idxs)))
                            )
                    else:
                        tuple_columns.append((name, *coords))

                dataframe.columns = tuple_columns
                dataframe.sort_index(axis=1, inplace=True)
            if df is None:
                df = dataframe
                continue
            df = df.join(dataframe, how="outer")
        # Move the (chain, draw) MultiIndex into ordinary columns.
        df = df.reset_index()
        dfs[group] = df
    if len(dfs) > 1:
        # Several groups selected: prefix every non-sample column with its
        # group name, then outer-merge all groups on chain/draw.
        for group, df in dfs.items():
            df.columns = [
                col
                if col in ("draw", "chain")
                else (group, col)
                if not isinstance(col, tuple)
                else (group, *col)
                for col in df.columns
            ]
        dfs, *dfs_tail = list(dfs.values())
        for df in dfs_tail:
            dfs = dfs.merge(df, how="outer", copy=False)
    else:
        (dfs,) = dfs.values()
    return dfs

def __add__(self, other):
"""Concatenate two InferenceData objects."""
return concat(self, other, copy=True, inplace=False)
Expand Down
77 changes: 77 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,83 @@ def test_to_dict_warmup(self):
test_xr_data = getattr(test_data, group)
assert xr_data.equals(test_xr_data)

@pytest.mark.parametrize(
    "kwargs",
    (
        {
            "groups": "posterior",
            "include_coords": True,
            "include_index": True,
            "index_origin": 0,
        },
        {
            "groups": ["posterior", "sample_stats"],
            "include_coords": False,
            "include_index": True,
            "index_origin": 0,
        },
        {
            "groups": "posterior_groups",
            "include_coords": True,
            "include_index": False,
            "index_origin": 1,
        },
    ),
)
def test_to_dataframe(self, kwargs):
    # One multidimensional (3x4x5) and one scalar variable per group;
    # observed_data must be skipped implicitly by to_dataframe.
    idata = from_dict(
        posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
        sample_stats={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
        observed_data={"a": np.random.randn(3, 4, 5), "b": np.random.randn(4)},
    )
    test_data = idata.to_dataframe(**kwargs)
    assert not test_data.empty
    groups = kwargs.get("groups", idata._groups_all)  # pylint: disable=protected-access
    for group in idata._groups_all:  # pylint: disable=protected-access
        if "data" in group:
            continue
        # 4 chains x 100 draws rows; columns are the unpacked "a" (3*4*5),
        # scalar "b", and the chain/draw index columns (+2).
        if groups == "posterior":
            expected_shape = (4 * 100, 3 * 4 * 5 + 1 + 2)
        else:
            expected_shape = (4 * 100, (3 * 4 * 5 + 1) * 2 + 2)
        assert test_data.shape == expected_shape
        if groups == "posterior":
            wants_coords = kwargs.get("include_coords", True)
            wants_index = kwargs.get("include_index", True)
            if wants_coords and wants_index:
                origin = kwargs.get("index_origin", 0)
                assert any(
                    "[{},".format(origin) in item[0]
                    for item in test_data.columns
                    if isinstance(item, tuple)
                )
            if wants_coords:
                assert any(isinstance(item, tuple) for item in test_data.columns)
            else:
                assert not any(isinstance(item, tuple) for item in test_data.columns)
        else:
            if not kwargs.get("include_index", True):
                expected_columns = (("posterior", "a", 1, 1, 1), ("posterior", "b"))
                assert all(item in test_data.columns for item in expected_columns)
            assert all(item in test_data.columns for item in ("chain", "draw"))

def test_to_dataframe_bad(self):
    """Invalid ``to_dataframe`` arguments raise TypeError or KeyError."""
    idata = from_dict(
        posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
        sample_stats={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
        observed_data={"a": np.random.randn(3, 4, 5), "b": np.random.randn(4)},
    )
    bad_calls = (
        # index_origin outside {0, 1}
        ({"index_origin": 2}, TypeError),
        # at least one of include_coords / include_index must be True
        ({"include_coords": False, "include_index": False}, TypeError),
        # data groups are skipped, leaving nothing to convert
        ({"groups": ["observed_data"]}, TypeError),
        # unknown group name
        ({"groups": ["invalid_group"]}, KeyError),
    )
    for bad_kwargs, expected_error in bad_calls:
        with pytest.raises(expected_error):
            idata.to_dataframe(**bad_kwargs)

@pytest.mark.parametrize("use", (None, "args", "kwargs"))
def test_map(self, use):
idata = load_arviz_data("centered_eight")
Expand Down