diff --git a/zarr/meta.py b/zarr/meta.py index 181c6efa69..acad19b036 100644 --- a/zarr/meta.py +++ b/zarr/meta.py @@ -166,6 +166,9 @@ def encode_fill_value(v, dtype): if v is None: return v if dtype.kind == 'f': + # these cases are handled in the ZarrJsonEncoder now, + # but there may be cases where the metadata was not parsed + # with this encoder (?), so better to leave it here if np.isnan(v): return 'NaN' elif np.isposinf(v): diff --git a/zarr/tests/test_attrs.py b/zarr/tests/test_attrs.py index 5ad39ef1a9..18a8bdb543 100644 --- a/zarr/tests/test_attrs.py +++ b/zarr/tests/test_attrs.py @@ -4,6 +4,7 @@ import pytest +import numpy as np from zarr.attrs import Attributes from zarr.tests.util import CountingDict @@ -218,3 +219,15 @@ def test_caching_off(self): assert 'spam' not in a assert 10 == store.counter['__getitem__', 'attrs'] assert 3 == store.counter['__setitem__', 'attrs'] + + def test_special_values(self): + a = self.init_attributes(dict()) + + a['nan'] = np.nan + assert np.isnan(a['nan']) + + a['pinf'] = np.PINF + assert np.isposinf(a['pinf']) + + a['ninf'] = np.NINF + assert np.isneginf(a['ninf']) diff --git a/zarr/util.py b/zarr/util.py index 47b1082941..33cd9d8733 100644 --- a/zarr/util.py +++ b/zarr/util.py @@ -19,15 +19,44 @@ } +class ZarrJsonEncoder(json.JsonEncoder): + """Encode json input + """ + def default(self, obj): + if np.isnan(obj): + return "NaN" + elif np.isposinf(obj): + return "Infinity" + elif np.isneginf(obj): + return "-Infinity" + # we could also allow for passing numpy dtypes: + # if isinstance(obj, np.dtype): + # return obj.item() + return super().default(obj) + + +class ZarrJsonDecoder(json.JsonEncoder): + """Decode json input + """ + def default(self, obj): + if obj == "NaN": + return np.nan + elif obj == "Infinity": + return np.PINF + elif obj == "-Infinity": + return np.NINF + return super().default(obj) + + def json_dumps(o): """Write JSON in a consistent, human-readable way.""" return json.dumps(o, indent=4, sort_keys=True, ensure_ascii=True, - separators=(',', ': ')).encode('ascii') + separators=(',', ': '), cls=ZarrJsonEncoder).encode('ascii') def json_loads(s): """Read JSON in a consistent way.""" - return json.loads(ensure_text(s, 'ascii')) + return json.loads(ensure_text(s, 'ascii'), cls=ZarrJsonDecoder) def normalize_shape(shape):