diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index d2bbc459d83..ff2ecbc74a1 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import JoinOptions, T_DataArray, T_Dataset + from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray def reindex_variables( @@ -173,7 +173,7 @@ def __init__( def _normalize_indexes( self, - indexes: Mapping[Any, Any], + indexes: Mapping[Any, Any | T_DuckArray], ) -> tuple[NormalizedIndexes, NormalizedIndexVars]: """Normalize the indexes/indexers used for re-indexing or alignment. @@ -194,7 +194,7 @@ def _normalize_indexes( f"Indexer has dimensions {idx.dims} that are different " f"from that to be indexed along '{k}'" ) - data = as_compatible_data(idx) + data: T_DuckArray = as_compatible_data(idx) pd_idx = safe_cast_to_index(data) pd_idx.name = k if isinstance(pd_idx, pd.MultiIndex): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2090a8ef989..48e25f7e1c7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7481,7 +7481,7 @@ def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset: else: variables[k] = f(v, *args, **kwargs) if keep_attrs: - variables[k].attrs = v._attrs + variables[k]._attrs = v._attrs attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, attrs=attrs) diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py index 26efc5fc412..333059e00ae 100644 --- a/xarray/core/parallelcompat.py +++ b/xarray/core/parallelcompat.py @@ -25,7 +25,7 @@ T_ChunkedArray = TypeVar("T_ChunkedArray") if TYPE_CHECKING: - from xarray.core.types import T_Chunks, T_NormalizedChunks + from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks @functools.lru_cache(maxsize=1) @@ -257,7 +257,7 @@ def normalize_chunks( @abstractmethod def from_array( - self, data: np.ndarray, chunks: T_Chunks, **kwargs + self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs ) -> T_ChunkedArray: """ Create a chunked array from a non-chunked numpy-like array. diff --git a/xarray/core/types.py b/xarray/core/types.py index 6c15d666f1c..f80c2c52cd7 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -161,6 +161,10 @@ def copy( T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") T_Alignable = TypeVar("T_Alignable", bound="Alignable") +# Temporary placeholder for indicating an array api compliant type. +# hopefully in the future we can narrow this down more: +T_DuckArray = TypeVar("T_DuckArray", bound=Any) + ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"] DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index bd0ca57f33c..ad86b2c7fec 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -73,7 +73,7 @@ import pandas as pd if TYPE_CHECKING: - from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims + from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray K = TypeVar("K") V = TypeVar("V") @@ -253,7 +253,7 @@ def is_list_like(value: Any) -> TypeGuard[list | tuple]: return isinstance(value, (list, tuple)) -def is_duck_array(value: Any) -> bool: +def is_duck_array(value: Any) -> TypeGuard[T_DuckArray]: if isinstance(value, np.ndarray): return True return ( diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 670c3179c6c..f459f044751 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -8,7 +8,7 @@ from collections.abc import Hashable, Iterable, Mapping, Sequence from datetime import timedelta from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn +from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast import numpy as np import pandas as pd @@ -66,6 +66,7 @@ PadModeOptions, PadReflectOptions, QuantileMethods, + T_DuckArray, T_Variable, ) @@ -86,7 +87,7 @@ class MissingDimensionsError(ValueError): # TODO: move this to an xarray.exceptions module? -def as_variable(obj, name=None) -> Variable | IndexVariable: +def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable: """Convert an object into a Variable. Parameters @@ -142,7 +143,7 @@ def as_variable(obj, name=None) -> Variable | IndexVariable: elif isinstance(obj, (set, dict)): raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}") elif name is not None: - data = as_compatible_data(obj) + data: T_DuckArray = as_compatible_data(obj) if data.ndim != 1: raise MissingDimensionsError( f"cannot set variable {name!r} with {data.ndim!r}-dimensional data " @@ -230,7 +231,9 @@ def _possibly_convert_datetime_or_timedelta_index(data): return data -def as_compatible_data(data, fastpath: bool = False): +def as_compatible_data( + data: T_DuckArray | ArrayLike, fastpath: bool = False +) -> T_DuckArray: """Prepare and wrap data to put in a Variable. - If data does not have the necessary attributes, convert it to ndarray. @@ -243,7 +246,7 @@ def as_compatible_data(data, fastpath: bool = False): """ if fastpath and getattr(data, "ndim", 0) > 0: # can't use fastpath (yet) for scalars - return _maybe_wrap_data(data) + return cast("T_DuckArray", _maybe_wrap_data(data)) from xarray.core.dataarray import DataArray @@ -252,7 +255,7 @@ def as_compatible_data(data, fastpath: bool = False): if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): data = _possibly_convert_datetime_or_timedelta_index(data) - return _maybe_wrap_data(data) + return cast("T_DuckArray", _maybe_wrap_data(data)) if isinstance(data, tuple): data = utils.to_0d_object_array(data) @@ -279,7 +282,7 @@ def as_compatible_data(data, fastpath: bool = False): if not isinstance(data, np.ndarray) and ( hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") ): - return data + return cast("T_DuckArray", data) # validate whether the data is valid data types. data = np.asarray(data) @@ -335,7 +338,14 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic): __slots__ = ("_dims", "_data", "_attrs", "_encoding") - def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): + def __init__( + self, + dims, + data: T_DuckArray | ArrayLike, + attrs=None, + encoding=None, + fastpath=False, + ): """ Parameters ---------- @@ -355,9 +365,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): Well-behaved code to serialize a Variable should ignore unrecognized encoding items. """ - self._data = as_compatible_data(data, fastpath=fastpath) + self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath) self._dims = self._parse_dimensions(dims) - self._attrs = None + self._attrs: dict[Any, Any] | None = None self._encoding = None if attrs is not None: self.attrs = attrs @@ -410,7 +420,7 @@ def _in_memory(self): ) @property - def data(self) -> Any: + def data(self: T_Variable): """ The Variable's data as an array. The underlying array type (e.g. dask, sparse, pint) is preserved. @@ -429,12 +439,12 @@ def data(self) -> Any: return self.values @data.setter - def data(self, data): + def data(self: T_Variable, data: T_DuckArray | ArrayLike) -> None: data = as_compatible_data(data) - if data.shape != self.shape: + if data.shape != self.shape: # type: ignore[attr-defined] raise ValueError( f"replacement data must match the Variable's shape. " - f"replacement data has shape {data.shape}; Variable has shape {self.shape}" + f"replacement data has shape {data.shape}; Variable has shape {self.shape}" # type: ignore[attr-defined] ) self._data = data @@ -996,7 +1006,7 @@ def reset_encoding(self: T_Variable) -> T_Variable: return self._replace(encoding={}) def copy( - self: T_Variable, deep: bool = True, data: ArrayLike | None = None + self: T_Variable, deep: bool = True, data: T_DuckArray | ArrayLike | None = None ) -> T_Variable: """Returns a copy of this object. @@ -1058,24 +1068,26 @@ def copy( def _copy( self: T_Variable, deep: bool = True, - data: ArrayLike | None = None, + data: T_DuckArray | ArrayLike | None = None, memo: dict[int, Any] | None = None, ) -> T_Variable: if data is None: - ndata = self._data + data_old = self._data - if isinstance(ndata, indexing.MemoryCachedArray): + if isinstance(data_old, indexing.MemoryCachedArray): # don't share caching between copies - ndata = indexing.MemoryCachedArray(ndata.array) + ndata = indexing.MemoryCachedArray(data_old.array) + else: + ndata = data_old if deep: ndata = copy.deepcopy(ndata, memo) else: ndata = as_compatible_data(data) - if self.shape != ndata.shape: + if self.shape != ndata.shape: # type: ignore[attr-defined] raise ValueError( - f"Data shape {ndata.shape} must match shape of object {self.shape}" + f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined] ) attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs) @@ -1248,11 +1260,11 @@ def chunk( inline_array=inline_array, ) - data = self._data - if chunkmanager.is_chunked_array(data): - data = chunkmanager.rechunk(data, chunks) # type: ignore[arg-type] + data_old = self._data + if chunkmanager.is_chunked_array(data_old): + data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type] else: - if isinstance(data, indexing.ExplicitlyIndexed): + if isinstance(data_old, indexing.ExplicitlyIndexed): # Unambiguously handle array storage backends (like NetCDF4 and h5py) # that can't handle general array indexing. For example, in netCDF4 you # can do "outer" indexing along two dimensions independent, which works @@ -1261,20 +1273,22 @@ def chunk( # Using OuterIndexer is a pragmatic choice: dask does not yet handle # different indexing types in an explicit way: # https://github.com/dask/dask/issues/2883 - data = indexing.ImplicitToExplicitIndexingAdapter( - data, indexing.OuterIndexer + ndata = indexing.ImplicitToExplicitIndexingAdapter( + data_old, indexing.OuterIndexer ) + else: + ndata = data_old if utils.is_dict_like(chunks): - chunks = tuple(chunks.get(n, s) for n, s in enumerate(data.shape)) + chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) - data = chunkmanager.from_array( - data, + data_chunked = chunkmanager.from_array( + ndata, chunks, # type: ignore[arg-type] **_from_array_kwargs, ) - return self._replace(data=data) + return self._replace(data=data_chunked) def to_numpy(self) -> np.ndarray: """Coerces wrapped data to numpy and returns a numpy.ndarray""" diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index fb917dfb254..6a8cd9c457b 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -126,7 +126,8 @@ def test_dask_distributed_write_netcdf_with_dimensionless_variables( @requires_cftime @requires_netCDF4 -def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path): +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_can_open_files_with_cftime_index(parallel, tmp_path): T = xr.cftime_range("20010101", "20010501", calendar="360_day") Lon = np.arange(100) data = np.random.random((T.size, Lon.size)) @@ -135,9 +136,55 @@ def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path): da.to_netcdf(file_path) with cluster() as (s, [a, b]): with Client(s["address"]): - for parallel in (False, True): - with xr.open_mfdataset(file_path, parallel=parallel) as tf: - assert_identical(tf["test"], da) + with xr.open_mfdataset(file_path, parallel=parallel) as tf: + assert_identical(tf["test"], da) + + +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path): + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + with cluster() as (s, [a, b]): + with Client(s["address"]): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) + + +# TODO: move this to test_backends.py +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path): + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]: + with dask.config.set(scheduler=get): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) @pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) diff --git a/xarray/tests/test_parallelcompat.py b/xarray/tests/test_parallelcompat.py index 2c3378a2816..ea324cafb76 100644 --- a/xarray/tests/test_parallelcompat.py +++ b/xarray/tests/test_parallelcompat.py @@ -12,7 +12,7 @@ guess_chunkmanager, list_chunkmanagers, ) -from xarray.core.types import T_Chunks, T_NormalizedChunks +from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks from xarray.tests import has_dask, requires_dask @@ -76,7 +76,7 @@ def normalize_chunks( return normalize_chunks(chunks, shape, limit, dtype, previous_chunks) def from_array( - self, data: np.ndarray, chunks: T_Chunks, **kwargs + self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs ) -> DummyChunkedArray: from dask import array as da diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 3c40d0a2361..4fcd5f98d8f 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -5,6 +5,7 @@ from copy import copy, deepcopy from datetime import datetime, timedelta from textwrap import dedent +from typing import Generic import numpy as np import pandas as pd @@ -26,6 +27,7 @@ VectorizedIndexer, ) from xarray.core.pycompat import array_type +from xarray.core.types import T_DuckArray from xarray.core.utils import NDArrayMixin from xarray.core.variable import as_compatible_data, as_variable from xarray.tests import ( @@ -2529,7 +2531,7 @@ def test_to_index_variable_copy(self) -> None: assert a.dims == ("x",) -class TestAsCompatibleData: +class TestAsCompatibleData(Generic[T_DuckArray]): def test_unchanged_types(self): types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray) for t in types: @@ -2610,17 +2612,17 @@ def test_tz_datetime(self) -> None: times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz)) with warnings.catch_warnings(): warnings.simplefilter("ignore") - actual = as_compatible_data(times_s) + actual: T_DuckArray = as_compatible_data(times_s) assert actual.array == times_s assert actual.array.dtype == pd.DatetimeTZDtype("ns", tz) series = pd.Series(times_s) with warnings.catch_warnings(): warnings.simplefilter("ignore") - actual = as_compatible_data(series) + actual2: T_DuckArray = as_compatible_data(series) - np.testing.assert_array_equal(actual, series.values) - assert actual.dtype == np.dtype("datetime64[ns]") + np.testing.assert_array_equal(actual2, series.values) + assert actual2.dtype == np.dtype("datetime64[ns]") def test_full_like(self) -> None: # For more thorough tests, see test_variable.py