diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py
index 89ceaddd93b..fa57bffb8d5 100644
--- a/xarray/coding/strings.py
+++ b/xarray/coding/strings.py
@@ -47,12 +47,11 @@ class EncodedStringCoder(VariableCoder):
     def __init__(self, allows_unicode=True):
         self.allows_unicode = allows_unicode
 
-    def encode(self, variable, name=None):
+    def encode(self, variable: Variable, name=None) -> Variable:
         dims, data, attrs, encoding = unpack_for_encoding(variable)
 
         contains_unicode = is_unicode_dtype(data.dtype)
         encode_as_char = encoding.get("dtype") == "S1"
-
         if encode_as_char:
             del encoding["dtype"]  # no longer relevant
 
@@ -69,9 +68,12 @@ def encode(self, variable, name=None):
             # TODO: figure out how to handle this in a lazy way with dask
             data = encode_string_array(data, string_encoding)
 
-        return Variable(dims, data, attrs, encoding)
+            return Variable(dims, data, attrs, encoding)
+        else:
+            variable.encoding = encoding
+            return variable
 
-    def decode(self, variable, name=None):
+    def decode(self, variable: Variable, name=None) -> Variable:
         dims, data, attrs, encoding = unpack_for_decoding(variable)
 
         if "_Encoding" in attrs:
@@ -95,13 +97,15 @@ def encode_string_array(string_array, encoding="utf-8"):
     return np.array(encoded, dtype=bytes).reshape(string_array.shape)
 
 
-def ensure_fixed_length_bytes(var):
+def ensure_fixed_length_bytes(var: Variable) -> Variable:
     """Ensure that a variable with vlen bytes is converted to fixed width."""
-    dims, data, attrs, encoding = unpack_for_encoding(var)
-    if check_vlen_dtype(data.dtype) == bytes:
+    if check_vlen_dtype(var.dtype) == bytes:
+        dims, data, attrs, encoding = unpack_for_encoding(var)
         # TODO: figure out how to handle this with dask
         data = np.asarray(data, dtype=np.bytes_)
-    return Variable(dims, data, attrs, encoding)
+        return Variable(dims, data, attrs, encoding)
+    else:
+        return var
 
 
 class CharacterArrayCoder(VariableCoder):
diff --git a/xarray/conventions.py b/xarray/conventions.py
index 8c7d6be2309..bf9f315c326 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -16,7 +16,7 @@
 )
 from xarray.core.pycompat import is_duck_dask_array
 from xarray.core.utils import emit_user_level_warning
-from xarray.core.variable import IndexVariable, Variable
+from xarray.core.variable import Variable
 
 CF_RELATED_DATA = (
     "bounds",
@@ -97,10 +97,10 @@ def _infer_dtype(array, name=None):
 
 
 def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
-    if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex):
+    if isinstance(var._data, indexing.PandasMultiIndexingAdapter):
         raise NotImplementedError(
             f"variable {name!r} is a MultiIndex, which cannot yet be "
-            "serialized to netCDF files. Instead, either use reset_index() "
+            "serialized. Instead, either use reset_index() "
             "to convert MultiIndex levels into coordinate variables instead "
             "or use https://cf-xarray.readthedocs.io/en/latest/coding.html."
         )
@@ -647,7 +647,9 @@ def cf_decoder(
     return variables, attributes
 
 
-def _encode_coordinates(variables, attributes, non_dim_coord_names):
+def _encode_coordinates(
+    variables: T_Variables, attributes: T_Attrs, non_dim_coord_names
+):
     # calculate global and variable specific coordinates
     non_dim_coord_names = set(non_dim_coord_names)
 
@@ -675,7 +677,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names):
                 variable_coordinates[k].add(coord_name)
 
             if any(
-                attr_name in v.encoding and coord_name in v.encoding.get(attr_name)
+                coord_name in v.encoding.get(attr_name, tuple())
                 for attr_name in CF_RELATED_DATA
             ):
                 not_technically_coordinates.add(coord_name)
@@ -742,7 +744,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names):
     return variables, attributes
 
 
-def encode_dataset_coordinates(dataset):
+def encode_dataset_coordinates(dataset: Dataset):
     """Encode coordinates on the given dataset object into variable specific
     and global attributes.
 
@@ -764,7 +766,7 @@ def encode_dataset_coordinates(dataset):
     )
 
 
-def cf_encoder(variables, attributes):
+def cf_encoder(variables: T_Variables, attributes: T_Attrs):
    """
    Encode a set of CF encoded variables and attributes.
    Takes a dicts of variables and attributes and encodes them
diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py
index 94d3ea92af2..b9190fb4252 100644
--- a/xarray/tests/test_coding_times.py
+++ b/xarray/tests/test_coding_times.py
@@ -733,7 +733,7 @@ def test_encode_time_bounds() -> None:
 
     # if time_bounds attrs are same as time attrs, it doesn't matter
     ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 2000-01-01"}
-    encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs)
+    encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs)
     assert_equal(encoded["time_bounds"], expected["time_bounds"])
     assert "calendar" not in encoded["time_bounds"].attrs
     assert "units" not in encoded["time_bounds"].attrs
@@ -741,7 +741,7 @@ def test_encode_time_bounds() -> None:
     # for CF-noncompliant case of time_bounds attrs being different from
     # time attrs; preserve them for faithful roundtrip
     ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 1849-01-01"}
-    encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs)
+    encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs)
     with pytest.raises(AssertionError):
         assert_equal(encoded["time_bounds"], expected["time_bounds"])
     assert "calendar" not in encoded["time_bounds"].attrs