Add support for netCDF4.EnumType (#8147)
bzah authored Jan 17, 2024
1 parent 33d51c8 commit d20ba0d
Showing 6 changed files with 214 additions and 25 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -224,6 +224,10 @@ New Features

- Use `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ for :py:func:`xarray.dot` by default if installed.
By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`7764`, :pull:`8373`).
- Decode and encode netCDF4 enums, storing the enum definition in dataarrays' dtype metadata.
If multiple variables share the same enum in netCDF4, each dataarray will have its own
enum definition in its dtype metadata.
By `Abel Aoun <https://github.com/bzah>`_ (:issue:`8144`, :pull:`8147`).
- Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`).
By `Ben Mares <https://github.com/maresb>`_.
- Allow passing ``region="auto"`` in :py:meth:`Dataset.to_zarr` to automatically infer the
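To make the new feature concrete, here is a minimal usage sketch. The file and variable names are illustrative; it assumes a netCDF4 file whose "clouds" variable uses an EnumType named "cloud_type", as in the tests added further below.

```python
import xarray as xr

# Illustrative file: "clouds.nc" stores a variable "clouds" whose netCDF4
# datatype is an EnumType named "cloud_type", e.g. {"clear": 0, "cloudy": 1}.
ds = xr.open_dataset("clouds.nc", engine="netcdf4")

meta = ds["clouds"].encoding["dtype"].metadata
print(meta["enum"])       # {'clear': 0, 'cloudy': 1}
print(meta["enum_name"])  # 'cloud_type'

# Each variable carries its own copy of the enum definition, even when the
# netCDF4 file stores a single shared EnumType.
```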
74 changes: 58 additions & 16 deletions xarray/backends/netCDF4_.py
@@ -49,7 +49,6 @@
# string used by netCDF4.
_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"}


NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK])


@@ -141,7 +140,9 @@ def _check_encoding_dtype_is_vlen_string(dtype):
)


def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False):
def _get_datatype(
var, nc_format="NETCDF4", raise_on_invalid_encoding=False
) -> np.dtype:
if nc_format == "NETCDF4":
return _nc4_dtype(var)
if "dtype" in var.encoding:
@@ -234,13 +235,13 @@ def _force_native_endianness(var):


def _extract_nc4_variable_encoding(
variable,
variable: Variable,
raise_on_invalid=False,
lsd_okay=True,
h5py_okay=False,
backend="netCDF4",
unlimited_dims=None,
):
) -> dict[str, Any]:
if unlimited_dims is None:
unlimited_dims = ()

@@ -308,7 +309,7 @@ def _extract_nc4_variable_encoding(
return encoding


def _is_list_of_strings(value):
def _is_list_of_strings(value) -> bool:
arr = np.asarray(value)
return arr.dtype.kind in ["U", "S"] and arr.size > 1

@@ -414,13 +415,25 @@ def _acquire(self, needs_lock=True):
def ds(self):
return self._acquire()

def open_store_variable(self, name, var):
def open_store_variable(self, name: str, var):
import netCDF4

dimensions = var.dimensions
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
attributes = {k: var.getncattr(k) for k in var.ncattrs()}
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
encoding: dict[str, Any] = {}
if isinstance(var.datatype, netCDF4.EnumType):
encoding["dtype"] = np.dtype(
data.dtype,
metadata={
"enum": var.datatype.enum_dict,
"enum_name": var.datatype.name,
},
)
else:
encoding["dtype"] = var.dtype
_ensure_fill_value_valid(data, attributes)
# netCDF4 specific encoding; save _FillValue for later
encoding = {}
filters = var.filters()
if filters is not None:
encoding.update(filters)
@@ -440,7 +453,6 @@ def open_store_variable(self, name, var):
# save source so __repr__ can detect if it's local or not
encoding["source"] = self._filename
encoding["original_shape"] = var.shape
encoding["dtype"] = var.dtype

return Variable(dimensions, data, attributes, encoding)

@@ -485,21 +497,24 @@ def encode_variable(self, variable):
return variable

def prepare_variable(
self, name, variable, check_encoding=False, unlimited_dims=None
self, name, variable: Variable, check_encoding=False, unlimited_dims=None
):
_ensure_no_forward_slash_in_name(name)

attrs = variable.attrs.copy()
fill_value = attrs.pop("_FillValue", None)
datatype = _get_datatype(
variable, self.format, raise_on_invalid_encoding=check_encoding
)
attrs = variable.attrs.copy()

fill_value = attrs.pop("_FillValue", None)

# check enum metadata and use netCDF4.EnumType
if (
(meta := np.dtype(datatype).metadata)
and (e_name := meta.get("enum_name"))
and (e_dict := meta.get("enum"))
):
datatype = self._build_and_get_enum(name, datatype, e_name, e_dict)
encoding = _extract_nc4_variable_encoding(
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
)

if name in self.ds.variables:
nc4_var = self.ds.variables[name]
else:
@@ -527,6 +542,33 @@ def prepare_variable(

return target, variable.data

def _build_and_get_enum(
self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int]
) -> Any:
"""
Add or get the netCDF4 Enum based on the dtype in encoding.
The return type should be ``netCDF4.EnumType``,
but we avoid importing netCDF4 globally for performance reasons.
"""
if enum_name not in self.ds.enumtypes:
return self.ds.createEnumType(
dtype,
enum_name,
enum_dict,
)
datatype = self.ds.enumtypes[enum_name]
if datatype.enum_dict != enum_dict:
error_msg = (
f"Cannot save variable `{var_name}` because an enum"
f" `{enum_name}` already exists in the Dataset but have"
" a different definition. To fix this error, make sure"
" each variable have a uniquely named enum in their"
" `encoding['dtype'].metadata` or, if they should share"
" the same enum type, make sure the enums are identical."
)
raise ValueError(error_msg)
return datatype

def sync(self):
self.ds.sync()

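For orientation, the writer path above (`prepare_variable` calling `_build_and_get_enum`) boils down to a get-or-create pattern against the netCDF4 API. A standalone sketch of that pattern, with illustrative names and an assumed "u1" base type:

```python
import netCDF4


def get_or_create_enum(nc: netCDF4.Dataset, enum_name: str, enum_dict: dict[str, int]):
    """Reuse an existing EnumType if its definition matches, otherwise create one."""
    if enum_name in nc.enumtypes:
        existing = nc.enumtypes[enum_name]
        if existing.enum_dict != enum_dict:
            raise ValueError(f"enum {enum_name!r} already exists with a different definition")
        return existing
    return nc.createEnumType("u1", enum_name, enum_dict)


# Illustrative usage against a freshly created file:
with netCDF4.Dataset("enum_demo.nc", mode="w") as nc:
    nc.createDimension("time", size=2)
    cloud_type = get_or_create_enum(nc, "cloud_type", {"clear": 0, "cloudy": 1})
    nc.createVariable("clouds", cloud_type, "time")
```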
21 changes: 20 additions & 1 deletion xarray/coding/variables.py
@@ -566,11 +566,30 @@ def decode(self):

class ObjectVLenStringCoder(VariableCoder):
def encode(self):
return NotImplementedError
raise NotImplementedError

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
if variable.dtype == object and variable.encoding.get("dtype", False) == str:
variable = variable.astype(variable.encoding["dtype"])
return variable
else:
return variable


class NativeEnumCoder(VariableCoder):
"""Encode Enum into variable dtype metadata."""

def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if (
"dtype" in variable.encoding
and np.dtype(variable.encoding["dtype"]).metadata
and "enum" in variable.encoding["dtype"].metadata
):
dims, data, attrs, encoding = unpack_for_encoding(variable)
data = data.astype(dtype=variable.encoding.pop("dtype"))
return Variable(dims, data, attrs, encoding, fastpath=True)
else:
return variable

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
raise NotImplementedError()
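`NativeEnumCoder.encode` above folds the enum dtype from `encoding` into the variable's data, relying on NumPy keeping dtype metadata attached through `astype`. A quick NumPy-only sketch of that mechanism (values are illustrative):

```python
import numpy as np

enum_dtype = np.dtype(
    "u1", metadata={"enum": {"clear": 0, "cloudy": 1}, "enum_name": "cloud_type"}
)

data = np.array([0, 1, 1], dtype="u1")
encoded = data.astype(enum_dtype)

# The cast keeps the metadata attached to the resulting dtype, which is how the
# netCDF4 backend later recovers the enum definition from the variable itself.
assert encoded.dtype.metadata["enum"] == {"clear": 0, "cloudy": 1}
assert encoded.dtype.metadata["enum_name"] == "cloud_type"
```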
17 changes: 9 additions & 8 deletions xarray/conventions.py
@@ -48,10 +48,6 @@
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]


def _var_as_tuple(var: Variable) -> T_VarTuple:
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()


def _infer_dtype(array, name=None):
"""Given an object array with no missing values, infer its dtype from all elements."""
if array.dtype.kind != "O":
@@ -111,7 +107,7 @@ def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
# TODO: move this from conventions to backends? (it's not CF related)
if var.dtype.kind == "O":
dims, data, attrs, encoding = _var_as_tuple(var)
dims, data, attrs, encoding = variables.unpack_for_encoding(var)

# leave vlen dtypes unchanged
if strings.check_vlen_dtype(data.dtype) is not None:
@@ -162,7 +158,7 @@ def encode_cf_variable(
var: Variable, needs_copy: bool = True, name: T_Name = None
) -> Variable:
"""
Converts an Variable into an Variable which follows some
Converts a Variable into a Variable which follows some
of the CF conventions:
- Nans are masked using _FillValue (or the deprecated missing_value)
@@ -188,6 +184,7 @@ def encode_cf_variable(
variables.CFScaleOffsetCoder(),
variables.CFMaskCoder(),
variables.UnsignedIntegerCoder(),
variables.NativeEnumCoder(),
variables.NonStringCoder(),
variables.DefaultFillvalueCoder(),
variables.BooleanCoder(),
@@ -447,7 +444,7 @@ def stackable(dim: Hashable) -> bool:
decode_timedelta=decode_timedelta,
)
except Exception as e:
raise type(e)(f"Failed to decode variable {k!r}: {e}")
raise type(e)(f"Failed to decode variable {k!r}: {e}") from e
if decode_coords in [True, "coordinates", "all"]:
var_attrs = new_vars[k].attrs
if "coordinates" in var_attrs:
@@ -633,7 +630,11 @@ def cf_decoder(
decode_cf_variable
"""
variables, attributes, _ = decode_cf_variables(
variables, attributes, concat_characters, mask_and_scale, decode_times
variables,
attributes,
concat_characters,
mask_and_scale,
decode_times,
)
return variables, attributes

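`encode_cf_variable` applies these coders in order, so `NativeEnumCoder` runs before the fill-value and dtype coders. A hedged sketch of how that surfaces for an enum-tagged variable; `xarray.conventions.encode_cf_variable` is not public API, so treat this as illustrative:

```python
import numpy as np
import xarray as xr
from xarray.conventions import encode_cf_variable  # internal API; path may change

var = xr.Variable(("time",), np.array([0, 1, 1], dtype="u1"))
var.encoding["dtype"] = np.dtype(
    "u1", metadata={"enum": {"clear": 0, "cloudy": 1}, "enum_name": "cloud_type"}
)

# Running the coder chain moves the enum dtype onto the data itself, where the
# netCDF4 backend can pick it up when writing.
encoded = encode_cf_variable(var, name="clouds")
assert encoded.dtype.metadata["enum_name"] == "cloud_type"
```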
3 changes: 3 additions & 0 deletions xarray/core/dataarray.py
@@ -4062,6 +4062,9 @@ def to_netcdf(
name is the same as a coordinate name, then it is given the name
``"__xarray_dataarray_variable__"``.
[netCDF4 backend only] netCDF4 enums are decoded into the
dataarray dtype metadata.
See Also
--------
Dataset.to_netcdf
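An end-to-end sketch of the docstring note above, round-tripping a DataArray through `to_netcdf` with the netCDF4 engine. File, variable, and enum names are illustrative, and fill-value handling is not covered here:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.array([0, 1, 1], dtype="u1"), dims="time", name="clouds")
# Attach the enum definition through dtype metadata before writing.
da.encoding["dtype"] = np.dtype(
    "u1", metadata={"enum": {"clear": 0, "cloudy": 1}, "enum_name": "cloud_type"}
)
da.to_netcdf("clouds.nc", engine="netcdf4")

# Reading the file back recovers the enum definition in the dtype metadata.
with xr.open_dataset("clouds.nc", engine="netcdf4") as ds:
    assert ds["clouds"].encoding["dtype"].metadata["enum"] == {"clear": 0, "cloudy": 1}
```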
120 changes: 120 additions & 0 deletions xarray/tests/test_backends.py
@@ -1704,6 +1704,126 @@ def test_raise_on_forward_slashes_in_names(self) -> None:
with self.roundtrip(ds):
pass

@requires_netCDF4
def test_encoding_enum__no_fill_value(self):
with create_tmp_file() as tmp_file:
cloud_type_dict = {"clear": 0, "cloudy": 1}
with nc4.Dataset(tmp_file, mode="w") as nc:
nc.createDimension("time", size=2)
cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
v = nc.createVariable(
"clouds",
cloud_type,
"time",
fill_value=None,
)
v[:] = 1
with open_dataset(tmp_file) as original:
save_kwargs = {}
if self.engine == "h5netcdf":
save_kwargs["invalid_netcdf"] = True
with self.roundtrip(original, save_kwargs=save_kwargs) as actual:
assert_equal(original, actual)
assert (
actual.clouds.encoding["dtype"].metadata["enum"]
== cloud_type_dict
)
if self.engine != "h5netcdf":
# not implemented in h5netcdf yet
assert (
actual.clouds.encoding["dtype"].metadata["enum_name"]
== "cloud_type"
)

@requires_netCDF4
def test_encoding_enum__multiple_variable_with_enum(self):
with create_tmp_file() as tmp_file:
cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
with nc4.Dataset(tmp_file, mode="w") as nc:
nc.createDimension("time", size=2)
cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
nc.createVariable(
"clouds",
cloud_type,
"time",
fill_value=255,
)
nc.createVariable(
"tifa",
cloud_type,
"time",
fill_value=255,
)
with open_dataset(tmp_file) as original:
save_kwargs = {}
if self.engine == "h5netcdf":
save_kwargs["invalid_netcdf"] = True
with self.roundtrip(original, save_kwargs=save_kwargs) as actual:
assert_equal(original, actual)
assert (
actual.clouds.encoding["dtype"] == actual.tifa.encoding["dtype"]
)
assert (
actual.clouds.encoding["dtype"].metadata
== actual.tifa.encoding["dtype"].metadata
)
assert (
actual.clouds.encoding["dtype"].metadata["enum"]
== cloud_type_dict
)
if self.engine != "h5netcdf":
# not implemented in h5netcdf yet
assert (
actual.clouds.encoding["dtype"].metadata["enum_name"]
== "cloud_type"
)

@requires_netCDF4
def test_encoding_enum__error_multiple_variable_with_changing_enum(self):
"""
Given two variables that share the same enum type,
the two enum definitions should be identical.
"""
with create_tmp_file() as tmp_file:
cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
with nc4.Dataset(tmp_file, mode="w") as nc:
nc.createDimension("time", size=2)
cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
nc.createVariable(
"clouds",
cloud_type,
"time",
fill_value=255,
)
nc.createVariable(
"tifa",
cloud_type,
"time",
fill_value=255,
)
with open_dataset(tmp_file) as original:
assert (
original.clouds.encoding["dtype"].metadata
== original.tifa.encoding["dtype"].metadata
)
modified_enum = original.clouds.encoding["dtype"].metadata["enum"]
modified_enum.update({"neblig": 2})
original.clouds.encoding["dtype"] = np.dtype(
"u1",
metadata={"enum": modified_enum, "enum_name": "cloud_type"},
)
if self.engine != "h5netcdf":
# not implemented yet in h5netcdf
with pytest.raises(
ValueError,
match=(
"Cannot save variable .*"
" because an enum `cloud_type` already exists in the Dataset .*"
),
):
with self.roundtrip(original):
pass


@requires_netCDF4
class TestNetCDF4Data(NetCDF4Base):
