diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index a846c1b8a01..378e6330352 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -32,7 +32,10 @@ New Features
 - :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`)
   By `Ignacio Martinez Vazquez `_.
 - Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg
-  `create_index=False`. (:pull:`8960`)
+  `create_index_for_new_dim=False`. (:pull:`8960`)
+  By `Tom Nicholas `_.
+- Avoid automatically re-creating 1D pandas indexes in :py:func:`concat()`. Also added option to avoid creating 1D indexes for
+  new dimension coordinates by passing the new kwarg `create_index_for_new_dim=False`. (:issue:`8871`, :pull:`8872`)
   By `Tom Nicholas `_.
 
 Breaking changes
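
A minimal sketch of the behaviour the release notes above describe (illustrative only, not part of the patch; assumes a build of xarray containing this change, and exact reprs may vary between versions):

    >>> import xarray as xr
    >>> ds = xr.Dataset({"a": ("x", [1, 2])}, coords={"x": [10, 20]}).drop_indexes("x")
    >>> # concat no longer silently re-creates the dropped pandas index for "x"
    >>> xr.concat([ds, ds], dim="x").indexes
    Indexes:
        *empty*
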
diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index d95cbccd36a..b1cca586992 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -8,6 +8,7 @@
 
 from xarray.core import dtypes, utils
 from xarray.core.alignment import align, reindex_variables
+from xarray.core.coordinates import Coordinates
 from xarray.core.duck_array_ops import lazy_array_equiv
 from xarray.core.indexes import Index, PandasIndex
 from xarray.core.merge import (
@@ -42,6 +43,7 @@ def concat(
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ) -> T_Dataset: ...
@@ -56,6 +58,7 @@ def concat(
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ) -> T_DataArray: ...
@@ -69,6 +72,7 @@ def concat(
     fill_value=dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ):
     """Concatenate xarray objects along a new or existing dimension.
@@ -162,6 +166,8 @@ def concat(
         If a callable, it must expect a sequence of ``attrs`` dicts and a context
         object as its only parameters.
+    create_index_for_new_dim : bool, default: True
+        Whether to create a new ``PandasIndex`` object when the objects being concatenated contain scalar variables named ``dim``.
 
     Returns
     -------
@@ -217,6 +223,25 @@ def concat(
         x        (new_dim) <U1 8B 'a' 'b'
         y        (y) int64 24B 10 20 30
       * new_dim  (new_dim) int64 16B -90 -100
+
+    # Concatenating scalar coordinates along a new dimension:
+
+    >>> ds = xr.Dataset(coords={"x": 0})
+    >>> xr.concat([ds, ds], dim="x")
+    <xarray.Dataset> Size: 16B
+    Dimensions:  (x: 2)
+    Coordinates:
+      * x        (x) int64 16B 0 0
+    Data variables:
+        *empty*
+
+    >>> xr.concat([ds, ds], dim="x").indexes
+    Indexes:
+        x        Index([0, 0], dtype='int64', name='x')
+
+    >>> xr.concat([ds, ds], dim="x", create_index_for_new_dim=False).indexes
+    Indexes:
+        *empty*
     """
     # TODO: add ignore_index arguments copied from pandas.concat
     # TODO: support concatenating scalar coordinates even if the concatenated
@@ -245,6 +270,7 @@ def concat(
             fill_value=fill_value,
             join=join,
             combine_attrs=combine_attrs,
+            create_index_for_new_dim=create_index_for_new_dim,
         )
     elif isinstance(first_obj, Dataset):
         return _dataset_concat(
@@ -257,6 +283,7 @@ def concat(
             fill_value=fill_value,
             join=join,
             combine_attrs=combine_attrs,
+            create_index_for_new_dim=create_index_for_new_dim,
         )
     else:
         raise TypeError(
@@ -439,7 +466,7 @@ def _parse_datasets(
             if dim in dims:
                 continue
 
-            if dim not in dim_coords:
+            if dim in ds.coords and dim not in dim_coords:
                 dim_coords[dim] = ds.coords[dim].variable
 
         dims = dims | set(ds.dims)
@@ -456,6 +483,7 @@ def _dataset_concat(
     fill_value: Any = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ) -> T_Dataset:
     """
     Concatenate a sequence of datasets along a new or existing dimension
@@ -489,7 +517,6 @@ def _dataset_concat(
         datasets
     )
     dim_names = set(dim_coords)
-    unlabeled_dims = dim_names - coord_names
 
     both_data_and_coords = coord_names & data_names
     if both_data_and_coords:
@@ -502,7 +529,10 @@ def _dataset_concat(
 
     # case where concat dimension is a coordinate or data_var but not a dimension
     if (dim in coord_names or dim in data_names) and dim not in dim_names:
-        datasets = [ds.expand_dims(dim) for ds in datasets]
+        datasets = [
+            ds.expand_dims(dim, create_index_for_new_dim=create_index_for_new_dim)
+            for ds in datasets
+        ]
 
     # determine which variables to concatenate
     concat_over, equals, concat_dim_lengths = _calc_concat_over(
@@ -510,7 +540,7 @@ def _dataset_concat(
     )
 
     # determine which variables to merge, and then merge them according to compat
-    variables_to_merge = (coord_names | data_names) - concat_over - unlabeled_dims
+    variables_to_merge = (coord_names | data_names) - concat_over
 
     result_vars = {}
     result_indexes = {}
@@ -567,7 +597,8 @@ def get_indexes(name):
             var = ds._variables[name]
             if not var.dims:
                 data = var.set_dims(dim).values
-                yield PandasIndex(data, dim, coord_dtype=var.dtype)
+                if create_index_for_new_dim:
+                    yield PandasIndex(data, dim, coord_dtype=var.dtype)
 
     # create concatenation index, needed for later reindexing
     file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
@@ -646,29 +677,33 @@ def get_indexes(name):
             # preserves original variable order
             result_vars[name] = result_vars.pop(name)
 
-    result = type(datasets[0])(result_vars, attrs=result_attrs)
-
-    absent_coord_names = coord_names - set(result.variables)
+    absent_coord_names = coord_names - set(result_vars)
     if absent_coord_names:
         raise ValueError(
             f"Variables {absent_coord_names!r} are coordinates in some datasets but not others."
         )
-    result = result.set_coords(coord_names)
-    result.encoding = result_encoding
-    result = result.drop_vars(unlabeled_dims, errors="ignore")
+    result_data_vars = {}
+    coord_vars = {}
+    for name, result_var in result_vars.items():
+        if name in coord_names:
+            coord_vars[name] = result_var
+        else:
+            result_data_vars[name] = result_var
 
     if index is not None:
-        # add concat index / coordinate last to ensure that its in the final Dataset
         if dim_var is not None:
             index_vars = index.create_variables({dim: dim_var})
         else:
             index_vars = index.create_variables()
-        result[dim] = index_vars[dim]
+
+        coord_vars[dim] = index_vars[dim]
         result_indexes[dim] = index
 
-    # TODO: add indexes at Dataset creation (when it is supported)
-    result = result._overwrite_indexes(result_indexes)
+    coords_obj = Coordinates(coord_vars, indexes=result_indexes)
+
+    result = type(datasets[0])(result_data_vars, coords=coords_obj, attrs=result_attrs)
+    result.encoding = result_encoding
 
     return result
@@ -683,6 +718,7 @@ def _dataarray_concat(
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
+    create_index_for_new_dim: bool = True,
 ) -> T_DataArray:
     from xarray.core.dataarray import DataArray
 
@@ -719,6 +755,7 @@ def _dataarray_concat(
         fill_value=fill_value,
         join=join,
         combine_attrs=combine_attrs,
+        create_index_for_new_dim=create_index_for_new_dim,
     )
 
     merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs)
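
A minimal standalone sketch of the construction pattern the refactored tail of `_dataset_concat` now relies on: building the result through a `Coordinates` object that carries an explicit `indexes` mapping, so no default pandas index is created behind the caller's back. Illustrative only; it uses the public `xr.Coordinates` API rather than the internal import added above.

    >>> import numpy as np
    >>> import xarray as xr
    >>> coords = xr.Coordinates({"x": np.array([1, 2, 3])}, indexes={})
    >>> ds = xr.Dataset({"a": ("x", [10, 20, 30])}, coords=coords)
    >>> ds.indexes
    Indexes:
        *empty*
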
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index c89dedf1215..4dc897c1878 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -2558,7 +2558,7 @@ def expand_dims(
         self,
         dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
         axis: None | int | Sequence[int] = None,
-        create_index: bool = True,
+        create_index_for_new_dim: bool = True,
         **dim_kwargs: Any,
     ) -> Self:
         """Return a new object with an additional axis (or axes) inserted at
@@ -2569,7 +2569,7 @@ def expand_dims(
         coordinate consisting of a single value.
 
         The automatic creation of indexes to back new 1D coordinate variables
-        controlled by the create_index kwarg.
+        is controlled by the create_index_for_new_dim kwarg.
 
         Parameters
         ----------
@@ -2586,8 +2586,8 @@ def expand_dims(
             multiple axes are inserted. In this case, dim arguments should be
             same length list. If axis=None is passed, all the axes will be
             inserted to the start of the result array.
-        create_index : bool, default is True
-            Whether to create new PandasIndex objects for any new 1D coordinate variables.
+        create_index_for_new_dim : bool, default: True
+            Whether to create new ``PandasIndex`` objects when the object being expanded contains scalar variables with names in ``dim``.
         **dim_kwargs : int or sequence or ndarray
             The keywords are arbitrary dimensions being inserted and the values
             are either the lengths of the new dims (if int is given), or their
@@ -2651,7 +2651,9 @@ def expand_dims(
             dim = {dim: 1}
         dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims")
-        ds = self._to_temp_dataset().expand_dims(dim, axis, create_index=create_index)
+        ds = self._to_temp_dataset().expand_dims(
+            dim, axis, create_index_for_new_dim=create_index_for_new_dim
+        )
         return self._from_temp_dataset(ds)
 
     def set_index(
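
The renamed kwarg behaves the same on `DataArray.expand_dims` as in the `Dataset` docstring examples added below; a short illustrative sketch (not part of the patch, assumes a build with this change):

    >>> import xarray as xr
    >>> da = xr.DataArray(5.0, coords={"x": 0})
    >>> da.expand_dims("x").indexes
    Indexes:
        x        Index([0], dtype='int64', name='x')
    >>> da.expand_dims("x", create_index_for_new_dim=False).indexes
    Indexes:
        *empty*
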
" f"To create an index for {k}, please first call `.set_coords('{k}')` on this object.", @@ -5400,7 +5427,7 @@ def to_stacked_array( [3, 4, 5, 7]]) Coordinates: * z (z) object 32B MultiIndex - * variable (z) object 32B 'a' 'a' 'a' 'b' + * variable (z) 1: - raise UnexpectedDataAccess("Tried accessing more than one element.") - return self.array[tuple_idxr] - - -class DuckArrayWrapper(utils.NDArrayMixin): - """Array-like that prevents casting to array. - Modeled after cupy.""" - - def __init__(self, array: np.ndarray): - self.array = array - - def __getitem__(self, key): - return type(self)(self.array[key]) - - def __array__(self, dtype: np.typing.DTypeLike = None): - raise UnexpectedDataAccess("Tried accessing data") - - def __array_namespace__(self): - """Present to satisfy is_duck_array test.""" - - class ReturnItem: def __getitem__(self, key): return key diff --git a/xarray/tests/arrays.py b/xarray/tests/arrays.py new file mode 100644 index 00000000000..983e620d1f0 --- /dev/null +++ b/xarray/tests/arrays.py @@ -0,0 +1,179 @@ +from collections.abc import Iterable +from typing import Any, Callable + +import numpy as np + +from xarray.core import utils +from xarray.core.indexing import ExplicitlyIndexed + +""" +This module contains various lazy array classes which can be wrapped and manipulated by xarray objects but will raise on data access. +""" + + +class UnexpectedDataAccess(Exception): + pass + + +class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed): + """Disallows any loading.""" + + def __init__(self, array): + self.array = array + + def get_duck_array(self): + raise UnexpectedDataAccess("Tried accessing data") + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") + + def __getitem__(self, key): + raise UnexpectedDataAccess("Tried accessing data.") + + +class FirstElementAccessibleArray(InaccessibleArray): + def __getitem__(self, key): + tuple_idxr = key.tuple + if len(tuple_idxr) > 1: + raise UnexpectedDataAccess("Tried accessing more than one element.") + return self.array[tuple_idxr] + + +class DuckArrayWrapper(utils.NDArrayMixin): + """Array-like that prevents casting to array. 
diff --git a/xarray/tests/arrays.py b/xarray/tests/arrays.py
new file mode 100644
index 00000000000..983e620d1f0
--- /dev/null
+++ b/xarray/tests/arrays.py
@@ -0,0 +1,179 @@
+from collections.abc import Iterable
+from typing import Any, Callable
+
+import numpy as np
+
+from xarray.core import utils
+from xarray.core.indexing import ExplicitlyIndexed
+
+"""
+This module contains various lazy array classes which can be wrapped and manipulated by xarray objects but will raise on data access.
+"""
+
+
+class UnexpectedDataAccess(Exception):
+    pass
+
+
+class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed):
+    """Disallows any loading."""
+
+    def __init__(self, array):
+        self.array = array
+
+    def get_duck_array(self):
+        raise UnexpectedDataAccess("Tried accessing data")
+
+    def __array__(self, dtype: np.typing.DTypeLike = None):
+        raise UnexpectedDataAccess("Tried accessing data")
+
+    def __getitem__(self, key):
+        raise UnexpectedDataAccess("Tried accessing data.")
+
+
+class FirstElementAccessibleArray(InaccessibleArray):
+    def __getitem__(self, key):
+        tuple_idxr = key.tuple
+        if len(tuple_idxr) > 1:
+            raise UnexpectedDataAccess("Tried accessing more than one element.")
+        return self.array[tuple_idxr]
+
+
+class DuckArrayWrapper(utils.NDArrayMixin):
+    """Array-like that prevents casting to array.
+    Modeled after cupy."""
+
+    def __init__(self, array: np.ndarray):
+        self.array = array
+
+    def __getitem__(self, key):
+        return type(self)(self.array[key])
+
+    def __array__(self, dtype: np.typing.DTypeLike = None):
+        raise UnexpectedDataAccess("Tried accessing data")
+
+    def __array_namespace__(self):
+        """Present to satisfy is_duck_array test."""
+
+
+CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS: dict[str, Callable] = {}
+
+
+def implements(numpy_function):
+    """Register an __array_function__ implementation for ConcatenatableArray objects."""
+
+    def decorator(func):
+        CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[numpy_function] = func
+        return func
+
+    return decorator
+
+
+@implements(np.concatenate)
+def concatenate(
+    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
+) -> "ConcatenatableArray":
+    if any(not isinstance(arr, ConcatenatableArray) for arr in arrays):
+        raise TypeError
+
+    result = np.concatenate([arr._array for arr in arrays], axis=axis)
+    return ConcatenatableArray(result)
+
+
+@implements(np.stack)
+def stack(
+    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
+) -> "ConcatenatableArray":
+    if any(not isinstance(arr, ConcatenatableArray) for arr in arrays):
+        raise TypeError
+
+    result = np.stack([arr._array for arr in arrays], axis=axis)
+    return ConcatenatableArray(result)
+
+
+@implements(np.result_type)
+def result_type(*arrays_and_dtypes) -> np.dtype:
+    """Called by xarray to ensure all arguments to concat have the same dtype."""
+    first_dtype, *other_dtypes = (np.dtype(obj) for obj in arrays_and_dtypes)
+    for other_dtype in other_dtypes:
+        if other_dtype != first_dtype:
+            raise ValueError("dtypes not all consistent")
+    return first_dtype
+
+
+@implements(np.broadcast_to)
+def broadcast_to(
+    x: "ConcatenatableArray", /, shape: tuple[int, ...]
+) -> "ConcatenatableArray":
+    """Broadcasts an array to a specified shape, without accessing its values."""
+    if not isinstance(x, ConcatenatableArray):
+        raise TypeError
+
+    result = np.broadcast_to(x._array, shape=shape)
+    return ConcatenatableArray(result)
+
+
+class ConcatenatableArray:
+    """Disallows loading or coercing to an index but does support concatenation / stacking."""
+
+    def __init__(self, array):
+        # use ._array instead of .array because we don't want this to be accessible even to xarray's internals (e.g. create_default_index_implicit)
+        self._array = array
+
+    @property
+    def dtype(self: Any) -> np.dtype:
+        return self._array.dtype
+
+    @property
+    def shape(self: Any) -> tuple[int, ...]:
+        return self._array.shape
+
+    @property
+    def ndim(self: Any) -> int:
+        return self._array.ndim
+
+    def __repr__(self: Any) -> str:
+        return f"{type(self).__name__}(array={self._array!r})"
+
+    def get_duck_array(self):
+        raise UnexpectedDataAccess("Tried accessing data")
+
+    def __array__(self, dtype: np.typing.DTypeLike = None):
+        raise UnexpectedDataAccess("Tried accessing data")
+
+    def __getitem__(self, key) -> "ConcatenatableArray":
+        """Some cases of concat require supporting expanding dims by dimensions of size 1"""
+        # see https://data-apis.org/array-api/2022.12/API_specification/indexing.html#multi-axis-indexing
+        arr = self._array
+        for axis, indexer_1d in enumerate(key):
+            if indexer_1d is None:
+                arr = np.expand_dims(arr, axis)
+            elif indexer_1d is Ellipsis:
+                pass
+            else:
+                raise UnexpectedDataAccess("Tried accessing data.")
+        return ConcatenatableArray(arr)
+
+    def __array_function__(self, func, types, args, kwargs) -> Any:
+        if func not in CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS:
+            return NotImplemented
+
+        # Note: this allows subclasses that don't override
+        # __array_function__ to handle ConcatenatableArray objects
+        if not all(issubclass(t, ConcatenatableArray) for t in types):
+            return NotImplemented
+
+        return CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs)
+
+    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any:
+        """We have to define this in order to convince xarray that this class is a duckarray, even though we will never support ufuncs."""
+        return NotImplemented
+
+    def astype(self, dtype: np.dtype, /, *, copy: bool = True) -> "ConcatenatableArray":
+        """Needed because xarray will call this even when it's a no-op"""
+        if dtype != self.dtype:
+            raise NotImplementedError()
+        else:
+            return self
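
How `ConcatenatableArray` is meant to be exercised (illustrative, not part of the patch): `np.concatenate` is routed through `__array_function__` to the implementation registered above, while any attempt to materialize the values raises.

    >>> import numpy as np
    >>> from xarray.tests.arrays import ConcatenatableArray
    >>> a = ConcatenatableArray(np.zeros((2, 3)))
    >>> np.concatenate([a, a], axis=0).shape
    (4, 3)
    >>> np.asarray(a)
    Traceback (most recent call last):
    ...
    xarray.tests.arrays.UnexpectedDataAccess: Tried accessing data
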
diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py
index 1ddb5a569bd..0c570de3b52 100644
--- a/xarray/tests/test_concat.py
+++ b/xarray/tests/test_concat.py
@@ -12,7 +12,9 @@
 from xarray.core.coordinates import Coordinates
 from xarray.core.indexes import PandasIndex
 from xarray.tests import (
+    ConcatenatableArray,
     InaccessibleArray,
+    UnexpectedDataAccess,
     assert_array_equal,
     assert_equal,
     assert_identical,
@@ -999,6 +1001,63 @@ def test_concat_str_dtype(self, dtype, dim) -> None:
 
         assert np.issubdtype(actual.x2.dtype, dtype)
 
+    def test_concat_avoids_index_auto_creation(self) -> None:
+        # TODO once passing indexes={} directly to Dataset constructor is allowed then no need to create coords first
+        coords = Coordinates(
+            {"x": ConcatenatableArray(np.array([1, 2, 3]))}, indexes={}
+        )
+        datasets = [
+            Dataset(
+                {"a": (["x", "y"], ConcatenatableArray(np.zeros((3, 3))))},
+                coords=coords,
+            )
+            for _ in range(2)
+        ]
+        # should not raise on concat
+        combined = concat(datasets, dim="x")
+        assert combined["a"].shape == (6, 3)
+        assert combined["a"].dims == ("x", "y")
+
+        # nor have auto-created any indexes
+        assert combined.indexes == {}
+
+        # should not raise on stack
+        combined = concat(datasets, dim="z")
+        assert combined["a"].shape == (2, 3, 3)
+        assert combined["a"].dims == ("z", "x", "y")
+
+        # nor have auto-created any indexes
+        assert combined.indexes == {}
+
+    def test_concat_avoids_index_auto_creation_new_1d_coord(self) -> None:
+        # create 0D coordinates (without indexes)
+        datasets = [
+            Dataset(
+                coords={"x": ConcatenatableArray(np.array(10))},
+            )
+            for _ in range(2)
+        ]
+
+        with pytest.raises(UnexpectedDataAccess):
+            concat(datasets, dim="x", create_index_for_new_dim=True)
+
+        # should not raise on concat iff create_index_for_new_dim=False
+        combined = concat(datasets, dim="x", create_index_for_new_dim=False)
+        assert combined["x"].shape == (2,)
+        assert combined["x"].dims == ("x",)
+
+        # nor have auto-created any indexes
+        assert combined.indexes == {}
+
+    def test_concat_promote_shape_without_creating_new_index(self) -> None:
+        # different shapes but neither have indexes
+        ds1 = Dataset(coords={"x": 0})
+        ds2 = Dataset(data_vars={"x": [1]}).drop_indexes("x")
+        actual = concat([ds1, ds2], dim="x", create_index_for_new_dim=False)
+        expected = Dataset(data_vars={"x": [0, 1]}).drop_indexes("x")
+        assert_identical(actual, expected, check_default_indexes=False)
+        assert actual.indexes == {}
+
 
 class TestConcatDataArray:
     def test_concat(self) -> None:
@@ -1072,6 +1131,35 @@ def test_concat_lazy(self) -> None:
         assert combined.shape == (2, 3, 3)
         assert combined.dims == ("z", "x", "y")
 
+    def test_concat_avoids_index_auto_creation(self) -> None:
+        # TODO once passing indexes={} directly to DataArray constructor is allowed then no need to create coords first
+        coords = Coordinates(
+            {"x": ConcatenatableArray(np.array([1, 2, 3]))}, indexes={}
+        )
+        arrays = [
+            DataArray(
+                ConcatenatableArray(np.zeros((3, 3))),
+                dims=["x", "y"],
+                coords=coords,
+            )
+            for _ in range(2)
+        ]
+        # should not raise on concat
+        combined = concat(arrays, dim="x")
+        assert combined.shape == (6, 3)
+        assert combined.dims == ("x", "y")
+
+        # nor have auto-created any indexes
+        assert combined.indexes == {}
+
+        # should not raise on stack
+        combined = concat(arrays, dim="z")
+        assert combined.shape == (2, 3, 3)
+        assert combined.dims == ("z", "x", "y")
+
+        # nor have auto-created any indexes
+        assert combined.indexes == {}
+
     @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0])
     def test_concat_fill_value(self, fill_value) -> None:
         foo = DataArray([1, 2], coords=[("x", [1, 2])])
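
The promote-shape test above pins down behaviour that can also be seen directly (illustrative sketch, assuming a build with this change; it mirrors the test verbatim):

    >>> import xarray as xr
    >>> ds1 = xr.Dataset(coords={"x": 0})
    >>> ds2 = xr.Dataset(data_vars={"x": [1]}).drop_indexes("x")
    >>> xr.concat([ds1, ds2], dim="x", create_index_for_new_dim=False).indexes
    Indexes:
        *empty*
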
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 59b5b2b9b71..584776197e3 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -3431,16 +3431,22 @@ def test_expand_dims_kwargs_python36plus(self) -> None:
         )
         assert_identical(other_way_expected, other_way)
 
-    @pytest.mark.parametrize("create_index_flag", [True, False])
-    def test_expand_dims_create_index_data_variable(self, create_index_flag):
+    @pytest.mark.parametrize("create_index_for_new_dim_flag", [True, False])
+    def test_expand_dims_create_index_data_variable(
+        self, create_index_for_new_dim_flag
+    ):
         # data variables should not gain an index ever
         ds = Dataset({"x": 0})
 
-        if create_index_flag:
+        if create_index_for_new_dim_flag:
             with pytest.warns(UserWarning, match="No index created"):
-                expanded = ds.expand_dims("x", create_index=create_index_flag)
+                expanded = ds.expand_dims(
+                    "x", create_index_for_new_dim=create_index_for_new_dim_flag
+                )
         else:
-            expanded = ds.expand_dims("x", create_index=create_index_flag)
+            expanded = ds.expand_dims(
+                "x", create_index_for_new_dim=create_index_for_new_dim_flag
+            )
 
         # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959
         expected = Dataset({"x": ("x", [0])}).drop_indexes("x").reset_coords("x")
@@ -3449,13 +3455,13 @@ def test_expand_dims_create_index_data_variable(self, create_index_flag):
         assert expanded.indexes == {}
 
     def test_expand_dims_create_index_coordinate_variable(self):
-        # coordinate variables should gain an index only if create_index is True (the default)
+        # coordinate variables should gain an index only if create_index_for_new_dim is True (the default)
         ds = Dataset(coords={"x": 0})
         expanded = ds.expand_dims("x")
         expected = Dataset({"x": ("x", [0])})
         assert_identical(expanded, expected)
 
-        expanded_no_index = ds.expand_dims("x", create_index=False)
+        expanded_no_index = ds.expand_dims("x", create_index_for_new_dim=False)
 
         # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959
         expected = Dataset(coords={"x": ("x", [0])}).drop_indexes("x")
@@ -3469,7 +3475,7 @@ def test_expand_dims_create_index_from_iterable(self):
         expected = Dataset({"x": ("x", [0, 1])})
         assert_identical(expanded, expected)
 
-        expanded_no_index = ds.expand_dims(x=[0, 1], create_index=False)
+        expanded_no_index = ds.expand_dims(x=[0, 1], create_index_for_new_dim=False)
 
         # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959
         expected = Dataset(coords={"x": ("x", [0, 1])}).drop_indexes("x")
@@ -3971,6 +3977,25 @@ def test_to_stacked_array_to_unstacked_dataset_different_dimension(self) -> None:
 
         x = y.to_unstacked_dataset("features")
         assert_identical(D, x)
 
+    def test_to_stacked_array_preserves_dtype(self) -> None:
+        # regression test for bug found in https://github.com/pydata/xarray/pull/8872#issuecomment-2081218616
+        ds = xr.Dataset(
+            data_vars={
+                "a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]),
+                "b": ("x", [6, 7]),
+            },
+            coords={"y": ["u", "v", "w"]},
+        )
+        stacked = ds.to_stacked_array("z", sample_dims=["x"])
+
+        # coordinate created from variable names should be of string dtype
+        data = np.array(["a", "a", "a", "b"], dtype="<U1")
+        expected_stacked_variable = IndexVariable("variable", data)
+        assert_identical(
+            stacked.coords["variable"].variable, expected_stacked_variable
+        )
+
     def test_update(self) -> None:
         data = create_test_data(seed=0)
         expected = data.copy()
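
Finally, the dtype regression covered by `test_to_stacked_array_preserves_dtype` can be sketched as follows (illustrative, not part of the patch; assumes a build with this change): the `variable` coordinate built from the variable names comes out with string dtype rather than object dtype.

    >>> import xarray as xr
    >>> ds = xr.Dataset(
    ...     data_vars={"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), "b": ("x", [6, 7])},
    ...     coords={"y": ["u", "v", "w"]},
    ... )
    >>> ds.to_stacked_array("z", sample_dims=["x"]).coords["variable"].dtype
    dtype('<U1')
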