-
-
Notifications
You must be signed in to change notification settings - Fork 407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add to_dataframe to InferenceData #1395
Changes from all commits
d86e25c
605e273
c95c29e
0b9e520
8c92e4d
7aa76d7
0aea43d
fa29104
344cbc4
08c22d8
4eddd6a
b3c6e28
b776912
d11c6d2
74d9dfb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -190,6 +190,12 @@ def __iter__(self): | |
for group in self._groups_all: | ||
yield group | ||
|
||
def __getitem__(self, key):
    """Return the group stored under ``key``.

    Only names listed in ``self._groups_all`` are accepted; any other key
    raises ``KeyError`` even if an attribute with that name exists.
    """
    if key in self._groups_all:
        return getattr(self, key)
    raise KeyError(key)
|
||
def groups(self):
    """List every group present in this InferenceData object.

    The returned object is ``self._groups_all`` itself, not a copy.
    """
    all_groups = self._groups_all
    return all_groups
|
@@ -278,7 +284,13 @@ def to_dict(self, groups=None, filter_groups=None): | |
Parameters | ||
---------- | ||
groups : list, optional | ||
Write only these groups to netcdf file. | ||
Groups where the transformation is to be applied. Can either be group names | ||
or metagroup names. | ||
filter_groups: {None, "like", "regex"}, optional, default=None | ||
If `None` (default), interpret groups as the real group or metagroup names. | ||
If "like", interpret groups as substrings of the real group or metagroup names. | ||
If "regex", interpret groups as regular expressions on the real group or | ||
metagroup names. A la `pandas.filter`. | ||
|
||
Returns | ||
------- | ||
|
@@ -328,13 +340,21 @@ def to_dict(self, groups=None, filter_groups=None): | |
ret["attrs"] = attrs | ||
return ret | ||
|
||
def to_json(self, filename, groups=None, filter_groups=None, **kwargs):
    """Write InferenceData to a json file.

    Parameters
    ----------
    filename : str
        Location to write to
    groups : list, optional
        Groups where the transformation is to be applied. Can either be group names
        or metagroup names.
    filter_groups: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret groups as the real group or metagroup names.
        If "like", interpret groups as substrings of the real group or metagroup names.
        If "regex", interpret groups as regular expressions on the real group or
        metagroup names. A la `pandas.filter`.
    kwargs : dict
        kwargs passed to json.dump()

    Returns
    -------
    str
        Location of json file
    """
    # Group selection is delegated to ``to_dict`` so both methods filter identically.
    idata_dict = _make_json_serializable(
        self.to_dict(groups=groups, filter_groups=filter_groups)
    )

    # JSON is UTF-8 by specification; pin the encoding so that passing
    # ``ensure_ascii=False`` through kwargs does not depend on the platform default.
    with open(filename, "w", encoding="utf-8") as file:
        json.dump(idata_dict, file, **kwargs)

    return filename
|
||
def to_dataframe(
    self,
    groups=None,
    filter_groups=None,
    include_coords=True,
    include_index=True,
    index_origin=None,
):
    """Convert InferenceData to a pandas DataFrame following xarray naming conventions.

    This returns dataframe in a "wide" -format, where each item in ndimensional array is
    unpacked. To access "tidy" -format, use xarray functionality found for each dataset.

    In case of a multiple groups, function adds a group identification to the var name.

    Data groups ("observed_data", "constant_data", "predictions_constant_data") are
    skipped implicitly.

    Parameters
    ----------
    groups: str or list of str, optional
        Groups where the transformation is to be applied. Can either be group names
        or metagroup names.
    filter_groups: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret groups as the real group or metagroup names.
        If "like", interpret groups as substrings of the real group or metagroup names.
        If "regex", interpret groups as regular expressions on the real group or
        metagroup names. A la `pandas.filter`.
    include_coords: bool
        Add coordinate values to column name (tuple).
    include_index: bool
        Add index information for multidimensional arrays.
    index_origin: {0, 1}, optional
        Starting index for multidimensional objects. 0- or 1-based.
        Defaults to rcParams["data.index_origin"].

    Returns
    -------
    pandas.DataFrame
        A pandas DataFrame containing all selected groups of InferenceData object.

    Raises
    ------
    TypeError
        If both ``include_coords`` and ``include_index`` are False, if
        ``index_origin`` is not 0 or 1, or if no valid groups are found.
    """
    # pylint: disable=too-many-nested-blocks
    if not include_coords and not include_index:
        raise TypeError("Both include_coords and include_index can not be False.")
    if index_origin is None:
        index_origin = rcParams["data.index_origin"]
    if index_origin not in [0, 1]:
        raise TypeError("index_origin must be 0 or 1, saw {}".format(index_origin))

    sample_dims = ("chain", "draw")

    # Any group whose name contains "data" is dropped on purpose, so
    # "posterior" + "observed_data" yields the posterior only, while asking for
    # "observed_data" alone ends in the TypeError below.
    group_names = [
        name for name in self._group_names(groups, filter_groups) if "data" not in name
    ]

    if not group_names:
        raise TypeError("No valid groups found: {}".format(groups))

    dfs = {}
    for group in group_names:
        dataset = self[group]
        # Map every coordinate value to its position along each non-sample dimension.
        # Built from the dimensions (not from all coords) so scalar or non-index
        # coordinates are ignored and dimensions without a coordinate still resolve.
        coords_to_idx = {
            name: {
                value: idx
                for idx, value in enumerate(dataset[name].values, index_origin)
            }
            for name in dataset.dims
            if name not in sample_dims
        }
        df = None
        for data_array in dataset.values():
            dataframe = data_array.to_dataframe()
            # Index levels (by position) that must become columns.
            levels = [
                idx for idx, dim in enumerate(data_array.dims) if dim not in sample_dims
            ]
            if levels:
                dataframe = dataframe.unstack(level=levels)
                tuple_columns = []
                for name, *coords in dataframe.columns:
                    if not include_index:
                        tuple_columns.append((name, *coords))
                        continue
                    idxs = [
                        coords_to_idx[coordname][coorditem]
                        for coordname, coorditem in zip(dataframe.columns.names[1:], coords)
                    ]
                    label = "{}[{}]".format(name, ",".join(map(str, idxs)))
                    tuple_columns.append((label, *coords) if include_coords else label)

                dataframe.columns = tuple_columns
                dataframe.sort_index(axis=1, inplace=True)
            df = dataframe if df is None else df.join(dataframe, how="outer")
        if df is None:
            # The group holds no variables; there is nothing to contribute.
            continue
        dfs[group] = df.reset_index()

    if not dfs:
        raise TypeError("No valid groups found: {}".format(groups))

    if len(dfs) == 1:
        (result,) = dfs.values()
        return result

    # With several groups, prefix every variable column with its group name and
    # keep the shared "chain"/"draw" columns untouched so they act as merge keys.
    for group, df in dfs.items():
        df.columns = [
            col
            if col in ("draw", "chain")
            else (group, *col)
            if isinstance(col, tuple)
            else (group, col)
            for col in df.columns
        ]
    result, *remaining = dfs.values()
    for df in remaining:
        result = result.merge(df, how="outer", copy=False)
    return result
|
||
def __add__(self, other):
    """Concatenate two InferenceData objects.

    Delegates to ``concat`` with ``copy=True`` and ``inplace=False``.
    """
    combined = concat(self, other, copy=True, inplace=False)
    return combined
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️