Skip to content

Commit

Permalink
Compress only a subset of types (#2129)
Browse files Browse the repository at this point in the history
* don't compress object dtype

* Update CHANGELOG.md

* fix compression dtypes

* fix

* change name

* add test

* tmp disable fix

* tmp disable

* tmp fix

* fix

* fix

* enable fix
  • Loading branch information
ahartikainen authored Oct 12, 2022
1 parent e1903c0 commit 3bf643c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* Correctly (re)order dimensions for `bfmi` and `plot_energy` ([2126](https://github.com/arviz-devs/arviz/pull/2126))
* Fix bug with the dimension order dependency ([2103](https://github.com/arviz-devs/arviz/pull/2103))
* Add testing module for labeller classes ([2095](https://github.com/arviz-devs/arviz/pull/2095))
* Compress only basic dtypes (skipping e.g. object and unicode string dtypes) when creating a netCDF file ([2129](https://github.com/arviz-devs/arviz/pull/2129))

### Deprecation
* Removed `fill_last`, `contour` and `plot_kwargs` arguments from `plot_pair` function ([2085](https://github.com/arviz-devs/arviz/pull/2085))
Expand Down
13 changes: 12 additions & 1 deletion arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@
InferenceDataT = TypeVar("InferenceDataT", bound="InferenceData")


def _compressible_dtype(dtype):
    """Check whether a dtype is safe to compress automatically.

    Only basic kinds are accepted: boolean, signed and unsigned integer,
    float, complex and byte string. A structured dtype is accepted when every
    one of its fields is itself accepted.

    Parameters
    ----------
    dtype : numpy.dtype
        Dtype of the variable that is about to be written.

    Returns
    -------
    bool
        True if compression can be requested for this dtype, False otherwise.
    """
    if dtype.kind != "V":
        return dtype.kind in {"b", "i", "u", "f", "c", "S"}
    if dtype.subdtype is not None:
        # Sub-array dtypes (e.g. the field type of ``("a", "i4", (2,))``) report
        # kind "V" but have no fields; what matters is their element dtype.
        return _compressible_dtype(dtype.subdtype[0])
    if dtype.fields is None:
        # Opaque void data (e.g. ``V8``) has no fields to inspect, so be
        # conservative and skip compression instead of raising.
        return False
    # Field entries are ``(dtype, offset)`` or ``(dtype, offset, title)``, so
    # index rather than unpack to also support titled fields.
    return all(_compressible_dtype(field[0]) for field in dtype.fields.values())


class InferenceData(Mapping[str, xr.Dataset]):
"""Container for inference data storage using xarray.
Expand Down Expand Up @@ -422,7 +429,11 @@ def to_netcdf(
data = getattr(self, group)
kwargs = {}
if compress:
kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
kwargs["encoding"] = {
var_name: {"zlib": True}
for var_name, values in data.variables.items()
if _compressible_dtype(values.dtype)
}
data.to_netcdf(filename, mode=mode, group=group, **kwargs)
data.close()
mode = "a"
Expand Down
9 changes: 6 additions & 3 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@ def get_inference_data(self, data, eight_schools_params):
prior_predictive=data.obj,
sample_stats_prior=data.obj,
observed_data=eight_schools_params,
coords={"school": np.arange(8)},
coords={"school": np.array(["a" * i for i in range(8)], dtype="U")},
dims={"theta": ["school"], "eta": ["school"]},
)

Expand Down Expand Up @@ -1253,7 +1253,8 @@ def test_io_function(self, data, eight_schools_params):
assert not os.path.exists(filepath)

@pytest.mark.parametrize("groups_arg", [False, True])
def test_io_method(self, data, eight_schools_params, groups_arg):
@pytest.mark.parametrize("compress", [True, False])
def test_io_method(self, data, eight_schools_params, groups_arg, compress):
# create InferenceData and check it has been properly created
inference_data = self.get_inference_data( # pylint: disable=W0612
data, eight_schools_params
Expand All @@ -1277,7 +1278,9 @@ def test_io_method(self, data, eight_schools_params, groups_arg):
assert not os.path.exists(filepath)
# InferenceData method
inference_data.to_netcdf(
filepath, groups=("posterior", "observed_data") if groups_arg else None
filepath,
groups=("posterior", "observed_data") if groups_arg else None,
compress=compress,
)

# assert file has been saved correctly
Expand Down

0 comments on commit 3bf643c

Please sign in to comment.