diff --git a/CHANGELOG.md b/CHANGELOG.md
index 07c69d717b..b4f058bea5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,8 @@
 
 ## v0.x.x Unreleased
 
 ### New features
+* Added `to_dataframe` method to InferenceData ([1395](https://github.com/arviz-devs/arviz/pull/1395))
+* Added `__getitem__` magic to InferenceData ([1395](https://github.com/arviz-devs/arviz/pull/1395))
 
 ### Maintenance and fixes
diff --git a/arviz/data/base.py b/arviz/data/base.py
index 9ed8011a2e..d42878169b 100644
--- a/arviz/data/base.py
+++ b/arviz/data/base.py
@@ -301,7 +301,7 @@ def _make_json_serializable(data: dict) -> dict:
     for key, value in data.items():
         try:
             json.dumps(value)
-        except TypeError:
+        except (TypeError, OverflowError):
             pass
         else:
             ret[key] = value
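Catching `OverflowError` alongside `TypeError` lets `_make_json_serializable` skip attribute values that `json.dumps` cannot encode instead of aborting serialization. A minimal sketch of the filtering behaviour, with hypothetical `attrs` values:

```python
import json

attrs = {
    "created_at": "2020-10-24",      # plain string: json.dumps succeeds, kept
    "inference_library": {"pymc3"},  # a set: json.dumps raises TypeError, dropped
}

ret = {}
for key, value in attrs.items():
    try:
        json.dumps(value)
    except (TypeError, OverflowError):  # skip anything json cannot encode
        pass
    else:
        ret[key] = value

print(ret)  # {'created_at': '2020-10-24'}
```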
diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py
index 69d15e3d34..6141700281 100644
--- a/arviz/data/inference_data.py
+++ b/arviz/data/inference_data.py
@@ -190,6 +190,12 @@ def __iter__(self):
         for group in self._groups_all:
             yield group
 
+    def __getitem__(self, key):
+        """Get item by key."""
+        if key not in self._groups_all:
+            raise KeyError(key)
+        return getattr(self, key)
+
     def groups(self):
         """Return all groups present in InferenceData object."""
         return self._groups_all
@@ -278,7 +284,13 @@ def to_dict(self, groups=None, filter_groups=None):
         Parameters
         ----------
         groups : list, optional
-            Write only these groups to netcdf file.
+            Groups where the transformation is to be applied. Can either be group names
+            or metagroup names.
+        filter_groups: {None, "like", "regex"}, optional, default=None
+            If `None` (default), interpret groups as the real group or metagroup names.
+            If "like", interpret groups as substrings of the real group or metagroup names.
+            If "regex", interpret groups as regular expressions on the real group or
+            metagroup names. Similar to `pandas.filter`.
 
         Returns
         -------
@@ -328,13 +340,21 @@ def to_dict(self, groups=None, filter_groups=None):
         ret["attrs"] = attrs
         return ret
 
-    def to_json(self, filename, **kwargs):
+    def to_json(self, filename, groups=None, filter_groups=None, **kwargs):
         """Write InferenceData to a json file.
 
         Parameters
         ----------
         filename : str
             Location to write to
+        groups : list, optional
+            Groups where the transformation is to be applied. Can either be group names
+            or metagroup names.
+        filter_groups: {None, "like", "regex"}, optional, default=None
+            If `None` (default), interpret groups as the real group or metagroup names.
+            If "like", interpret groups as substrings of the real group or metagroup names.
+            If "regex", interpret groups as regular expressions on the real group or
+            metagroup names. Similar to `pandas.filter`.
         kwargs : dict
             kwargs passed to json.dump()
 
@@ -343,13 +363,132 @@
         str
             Location of json file
         """
-        idata_dict = _make_json_serializable(self.to_dict())
+        idata_dict = _make_json_serializable(
+            self.to_dict(groups=groups, filter_groups=filter_groups)
+        )
         with open(filename, "w") as file:
             json.dump(idata_dict, file, **kwargs)
         return filename
 
+    def to_dataframe(
+        self,
+        groups=None,
+        filter_groups=None,
+        include_coords=True,
+        include_index=True,
+        index_origin=None,
+    ):
+        """Convert InferenceData to a pandas DataFrame following xarray naming conventions.
+
+        This returns a dataframe in "wide" format, where each element of the
+        n-dimensional arrays is unpacked into its own column. To access the "tidy"
+        (long) format, use the xarray functionality available on each dataset.
+
+        When multiple groups are selected, the group name is added to the variable
+        name as an identifier.
+
+        Data groups ("observed_data", "constant_data", "predictions_constant_data") are
+        skipped implicitly.
+
+        Raises TypeError if no valid groups are found.
+
+        Parameters
+        ----------
+        groups: str or list of str, optional
+            Groups where the transformation is to be applied. Can either be group names
+            or metagroup names.
+        filter_groups: {None, "like", "regex"}, optional, default=None
+            If `None` (default), interpret groups as the real group or metagroup names.
+            If "like", interpret groups as substrings of the real group or metagroup names.
+            If "regex", interpret groups as regular expressions on the real group or
+            metagroup names. Similar to `pandas.filter`.
+        include_coords: bool
+            Add coordinate values to column name (tuple). Defaults to True.
+        include_index: bool
+            Add index information for multidimensional arrays. Defaults to True.
+        index_origin: {0, 1}, optional
+            Starting index for multidimensional objects. 0- or 1-based.
+            Defaults to rcParams["data.index_origin"].
+
+        Returns
+        -------
+        pandas.DataFrame
+            A pandas DataFrame containing all selected groups of the InferenceData object.
+        """
+        # pylint: disable=too-many-nested-blocks
+        if not include_coords and not include_index:
+            raise TypeError("include_coords and include_index cannot both be False.")
+        if index_origin is None:
+            index_origin = rcParams["data.index_origin"]
+        if index_origin not in [0, 1]:
+            raise TypeError("index_origin must be 0 or 1, saw {}".format(index_origin))
+
+        group_names = list(
+            filter(lambda x: "data" not in x, self._group_names(groups, filter_groups))
+        )
+
+        if not group_names:
+            raise TypeError("No valid groups found: {}".format(groups))
+
+        dfs = {}
+        for group in group_names:
+            dataset = self[group]
+            df = None
+            coords_to_idx = {
+                name: dict(map(reversed, enumerate(dataset.coords[name].values, index_origin)))
+                for name in list(filter(lambda x: x not in ("chain", "draw"), dataset.coords))
+            }
+            for data_array in dataset.values():
+                dataframe = data_array.to_dataframe()
+                if list(filter(lambda x: x not in ("chain", "draw"), data_array.dims)):
+                    levels = [
+                        idx
+                        for idx, dim in enumerate(data_array.dims)
+                        if dim not in ("chain", "draw")
+                    ]
+                    dataframe = dataframe.unstack(level=levels)
+                    tuple_columns = []
+                    for name, *coords in dataframe.columns:
+                        if include_index:
+                            idxs = []
+                            for coordname, coorditem in zip(dataframe.columns.names[1:], coords):
+                                idxs.append(coords_to_idx[coordname][coorditem])
+                            if include_coords:
+                                tuple_columns.append(
+                                    ("{}[{}]".format(name, ",".join(map(str, idxs))), *coords)
+                                )
+                            else:
+                                tuple_columns.append(
+                                    "{}[{}]".format(name, ",".join(map(str, idxs)))
+                                )
+                        else:
+                            tuple_columns.append((name, *coords))
+
+                    dataframe.columns = tuple_columns
+                    dataframe.sort_index(axis=1, inplace=True)
+                if df is None:
+                    df = dataframe
+                    continue
+                df = df.join(dataframe, how="outer")
+            df = df.reset_index()
+            dfs[group] = df
+        if len(dfs) > 1:
+            for group, df in dfs.items():
+                df.columns = [
+                    col
+                    if col in ("draw", "chain")
+                    else (group, col)
+                    if not isinstance(col, tuple)
+                    else (group, *col)
+                    for col in df.columns
+                ]
+            dfs, *dfs_tail = list(dfs.values())
+            for df in dfs_tail:
+                dfs = dfs.merge(df, how="outer", copy=False)
+        else:
+            (dfs,) = dfs.values()
+        return dfs
+
     def __add__(self, other):
         """Concatenate two InferenceData objects."""
         return concat(self, other, copy=True, inplace=False)
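For review context, here is a hypothetical usage sketch of the two new entry points, using the bundled `centered_eight` example dataset:

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")

# __getitem__: dict-style group access, mirroring attribute access
posterior = idata["posterior"]  # same dataset as idata.posterior
# idata["no_such_group"] raises KeyError rather than AttributeError

# to_dataframe: wide-format frame with "chain" and "draw" columns plus one
# column per unpacked array element, e.g. ("theta[0]", "Choate")
df = idata.to_dataframe(groups="posterior", index_origin=0)
print(df.shape)
```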
diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py
index 9278b35740..6de15f4a1f 100644
--- a/arviz/tests/base_tests/test_data.py
+++ b/arviz/tests/base_tests/test_data.py
@@ -593,6 +593,83 @@ def test_to_dict_warmup(self):
         test_xr_data = getattr(test_data, group)
         assert xr_data.equals(test_xr_data)
 
+    @pytest.mark.parametrize(
+        "kwargs",
+        (
+            {
+                "groups": "posterior",
+                "include_coords": True,
+                "include_index": True,
+                "index_origin": 0,
+            },
+            {
+                "groups": ["posterior", "sample_stats"],
+                "include_coords": False,
+                "include_index": True,
+                "index_origin": 0,
+            },
+            {
+                "groups": "posterior_groups",
+                "include_coords": True,
+                "include_index": False,
+                "index_origin": 1,
+            },
+        ),
+    )
+    def test_to_dataframe(self, kwargs):
+        idata = from_dict(
+            posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
+            sample_stats={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
+            observed_data={"a": np.random.randn(3, 4, 5), "b": np.random.randn(4)},
+        )
+        test_data = idata.to_dataframe(**kwargs)
+        assert not test_data.empty
+        groups = kwargs.get("groups", idata._groups_all)  # pylint: disable=protected-access
+        for group in idata._groups_all:  # pylint: disable=protected-access
+            if "data" in group:
+                continue
+            assert test_data.shape == (
+                (4 * 100, 3 * 4 * 5 + 1 + 2)
+                if groups == "posterior"
+                else (4 * 100, (3 * 4 * 5 + 1) * 2 + 2)
+            )
+            if groups == "posterior":
+                if kwargs.get("include_coords", True) and kwargs.get("include_index", True):
+                    assert any(
+                        "[{},".format(kwargs.get("index_origin", 0)) in item[0]
+                        for item in test_data.columns
+                        if isinstance(item, tuple)
+                    )
+                if kwargs.get("include_coords", True):
+                    assert any(isinstance(item, tuple) for item in test_data.columns)
+                else:
+                    assert not any(isinstance(item, tuple) for item in test_data.columns)
+            else:
+                if not kwargs.get("include_index", True):
+                    assert all(
+                        item in test_data.columns
+                        for item in (("posterior", "a", 1, 1, 1), ("posterior", "b"))
+                    )
+            assert all(item in test_data.columns for item in ("chain", "draw"))
+
+    def test_to_dataframe_bad(self):
+        idata = from_dict(
+            posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
+            sample_stats={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
+            observed_data={"a": np.random.randn(3, 4, 5), "b": np.random.randn(4)},
+        )
+        with pytest.raises(TypeError):
+            idata.to_dataframe(index_origin=2)
+
+        with pytest.raises(TypeError):
+            idata.to_dataframe(include_coords=False, include_index=False)
+
+        with pytest.raises(TypeError):
+            idata.to_dataframe(groups=["observed_data"])
+
+        with pytest.raises(KeyError):
+            idata.to_dataframe(groups=["invalid_group"])
+
     @pytest.mark.parametrize("use", (None, "args", "kwargs"))
     def test_map(self, use):
         idata = load_arviz_data("centered_eight")
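As the tests exercise, the `groups`/`filter_groups` pair now behaves consistently across `to_dict`, `to_json`, and `to_dataframe`. A short hypothetical sketch (file name and match patterns invented):

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")

# exact group (or metagroup) names
idata.to_json("posterior_only.json", groups=["posterior"])

# substring matching, as in pandas.filter(like=...)
stats = idata.to_dict(groups="stats", filter_groups="like")

# regular expressions, as in pandas.filter(regex=...)
df = idata.to_dataframe(groups="^post", filter_groups="regex")
```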