Add T_DuckArray type hint to Variable.data #8203

Merged: 22 commits, Sep 19, 2023
6 changes: 3 additions & 3 deletions xarray/core/alignment.py
@@ -26,7 +26,7 @@
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray


def reindex_variables(
@@ -173,7 +173,7 @@ def __init__(

def _normalize_indexes(
self,
indexes: Mapping[Any, Any],
indexes: Mapping[Any, Any | T_DuckArray],
) -> tuple[NormalizedIndexes, NormalizedIndexVars]:
"""Normalize the indexes/indexers used for re-indexing or alignment.

@@ -194,7 +194,7 @@ def _normalize_indexes(
f"Indexer has dimensions {idx.dims} that are different "
f"from that to be indexed along '{k}'"
)
data = as_compatible_data(idx)
data: T_DuckArray = as_compatible_data(idx)
pd_idx = safe_cast_to_index(data)
pd_idx.name = k
if isinstance(pd_idx, pd.MultiIndex):
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
@@ -7481,7 +7481,7 @@ def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset:
else:
variables[k] = f(v, *args, **kwargs)
if keep_attrs:
variables[k].attrs = v._attrs
variables[k]._attrs = v._attrs
attrs = self._attrs if keep_attrs else None
return self._replace_with_new_dims(variables, attrs=attrs)

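Note on the `attrs` vs `_attrs` line above: writing to the private slot avoids going through the public property. A minimal sketch of the distinction, assuming (as in xarray's `Variable`) that the public `attrs` setter coerces its value with `dict(...)` while the private slot takes the value as given, so copying `v._attrs` (which may be `None`) through the property would raise, whereas the slot assignment preserves it:

```python
from __future__ import annotations

from typing import Any


class AttrsBox:
    """Toy stand-in for Variable's attrs handling (assumed behaviour only)."""

    def __init__(self) -> None:
        self._attrs: dict[Any, Any] | None = None

    @property
    def attrs(self) -> dict[Any, Any]:
        if self._attrs is None:
            self._attrs = {}
        return self._attrs

    @attrs.setter
    def attrs(self, value: Any) -> None:
        self._attrs = dict(value)  # rejects None, copies everything else


src, dst = AttrsBox(), AttrsBox()

dst._attrs = src._attrs        # private slot: None is carried over untouched

try:
    dst.attrs = src._attrs     # property: dict(None) raises TypeError
except TypeError as err:
    print(err)
```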
4 changes: 2 additions & 2 deletions xarray/core/parallelcompat.py
@@ -25,7 +25,7 @@
T_ChunkedArray = TypeVar("T_ChunkedArray")

if TYPE_CHECKING:
from xarray.core.types import T_Chunks, T_NormalizedChunks
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks


@functools.lru_cache(maxsize=1)
@@ -257,7 +257,7 @@ def normalize_chunks(

@abstractmethod
def from_array(
self, data: np.ndarray, chunks: T_Chunks, **kwargs
self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs
) -> T_ChunkedArray:
"""
Create a chunked array from a non-chunked numpy-like array.
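For context, the abstract `from_array` hook now advertises that it accepts any duck array or plain array-like, not only `np.ndarray`. A rough sketch of the widened signature, using local stand-ins for xarray's `T_DuckArray` and `T_Chunks` aliases (the body is a placeholder, not the real chunk-manager implementation):

```python
from __future__ import annotations

from typing import Any, TypeVar

import numpy as np

# Hypothetical stand-ins for the aliases defined in xarray.core.types.
T_DuckArray = TypeVar("T_DuckArray", bound=Any)
T_Chunks = Any


def from_array(
    data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs: Any
) -> np.ndarray:
    # A real chunk manager would hand `data` to its backend (the Dask
    # entrypoint calls dask.array.from_array); np.asarray is a placeholder.
    return np.asarray(data)


from_array([1, 2, 3], chunks=(3,))      # plain array-likes are still accepted
from_array(np.arange(3), chunks=(3,))   # as are duck arrays such as ndarray
```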
4 changes: 4 additions & 0 deletions xarray/core/types.py
@@ -161,6 +161,10 @@ def copy(
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
T_Alignable = TypeVar("T_Alignable", bound="Alignable")

# Temporary placeholder for indicating an array api compliant type.
# hopefully in the future we can narrow this down more:
T_DuckArray = TypeVar("T_DuckArray", bound=Any)

ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"]
DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"]
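The new alias is deliberately loose (`bound=Any`), so for now it mainly ties an input array type to an output array type. A small sketch of the kind of inference this enables, with a local stand-in for the alias:

```python
from typing import Any, TypeVar

import numpy as np

# Local stand-in for xarray.core.types.T_DuckArray.
T_DuckArray = TypeVar("T_DuckArray", bound=Any)


def roundtrip(data: T_DuckArray) -> T_DuckArray:
    # Whatever concrete array type goes in is what a type checker reports
    # coming out (np.ndarray, dask.array.Array, sparse.COO, ...).
    return data


arr = roundtrip(np.arange(3))  # inferred as np.ndarray rather than Any
print(arr.shape)
```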
4 changes: 2 additions & 2 deletions xarray/core/utils.py
@@ -73,7 +73,7 @@
import pandas as pd

if TYPE_CHECKING:
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray

K = TypeVar("K")
V = TypeVar("V")
@@ -253,7 +253,7 @@ def is_list_like(value: Any) -> TypeGuard[list | tuple]:
return isinstance(value, (list, tuple))


def is_duck_array(value: Any) -> bool:
def is_duck_array(value: Any) -> TypeGuard[T_DuckArray]:
if isinstance(value, np.ndarray):
return True
return (
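Returning `TypeGuard[T_DuckArray]` instead of `bool` lets a `True` result narrow the checked value for static type checkers. A simplified sketch of the pattern (the predicate here only checks for `np.ndarray`; the real xarray check also accepts objects exposing `shape`, `dtype`, and `__array_function__`/`__array_namespace__`):

```python
from typing import Any, TypeGuard  # TypeGuard needs Python >= 3.10 (or typing_extensions)

import numpy as np


def looks_like_duck_array(value: Any) -> TypeGuard[np.ndarray]:
    # Simplified predicate standing in for xarray.core.utils.is_duck_array.
    return isinstance(value, np.ndarray)


def total_size(value: Any) -> int:
    if looks_like_duck_array(value):
        # Within this branch, type checkers treat `value` as np.ndarray,
        # so .size is known to exist without a cast or ignore comment.
        return int(value.size)
    return 0


print(total_size(np.ones((2, 3))))  # 6
print(total_size("not an array"))   # 0
```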
76 changes: 45 additions & 31 deletions xarray/core/variable.py
@@ -8,7 +8,7 @@
from collections.abc import Hashable, Iterable, Mapping, Sequence
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast

import numpy as np
import pandas as pd
@@ -66,6 +66,7 @@
PadModeOptions,
PadReflectOptions,
QuantileMethods,
T_DuckArray,
T_Variable,
)

@@ -86,7 +87,7 @@ class MissingDimensionsError(ValueError):
# TODO: move this to an xarray.exceptions module?


def as_variable(obj, name=None) -> Variable | IndexVariable:
def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable:
"""Convert an object into a Variable.

Parameters
@@ -142,7 +143,7 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
elif isinstance(obj, (set, dict)):
raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}")
elif name is not None:
data = as_compatible_data(obj)
data: T_DuckArray = as_compatible_data(obj)
if data.ndim != 1:
raise MissingDimensionsError(
f"cannot set variable {name!r} with {data.ndim!r}-dimensional data "
@@ -230,7 +231,9 @@ def _possibly_convert_datetime_or_timedelta_index(data):
return data


def as_compatible_data(data, fastpath: bool = False):
def as_compatible_data(
data: T_DuckArray | ArrayLike, fastpath: bool = False
) -> T_DuckArray:
"""Prepare and wrap data to put in a Variable.

- If data does not have the necessary attributes, convert it to ndarray.
@@ -243,7 +246,7 @@ def as_compatible_data(data, fastpath: bool = False):
"""
if fastpath and getattr(data, "ndim", 0) > 0:
# can't use fastpath (yet) for scalars
return _maybe_wrap_data(data)
return cast("T_DuckArray", _maybe_wrap_data(data))

from xarray.core.dataarray import DataArray

@@ -252,7 +255,7 @@

if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
data = _possibly_convert_datetime_or_timedelta_index(data)
return _maybe_wrap_data(data)
return cast("T_DuckArray", _maybe_wrap_data(data))

if isinstance(data, tuple):
data = utils.to_0d_object_array(data)
@@ -279,7 +282,7 @@
if not isinstance(data, np.ndarray) and (
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
):
return data
return cast("T_DuckArray", data)

# validate whether the data is valid data types.
data = np.asarray(data)
@@ -335,7 +338,14 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):

__slots__ = ("_dims", "_data", "_attrs", "_encoding")

def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
def __init__(
self,
dims,
data: T_DuckArray | ArrayLike,
attrs=None,
encoding=None,
fastpath=False,
):
"""
Parameters
----------
@@ -355,9 +365,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
Well-behaved code to serialize a Variable should ignore
unrecognized encoding items.
"""
self._data = as_compatible_data(data, fastpath=fastpath)
self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath)
self._dims = self._parse_dimensions(dims)
self._attrs = None
self._attrs: dict[Any, Any] | None = None
self._encoding = None
if attrs is not None:
self.attrs = attrs
@@ -410,7 +420,7 @@ def _in_memory(self):
)

@property
def data(self) -> Any:
def data(self: T_Variable):
"""
The Variable's data as an array. The underlying array type
(e.g. dask, sparse, pint) is preserved.
@@ -429,12 +439,12 @@ def data(self) -> Any:
return self.values

@data.setter
def data(self, data):
def data(self: T_Variable, data: T_DuckArray | ArrayLike) -> None:
data = as_compatible_data(data)
if data.shape != self.shape:
if data.shape != self.shape: # type: ignore[attr-defined]
raise ValueError(
f"replacement data must match the Variable's shape. "
f"replacement data has shape {data.shape}; Variable has shape {self.shape}"
f"replacement data has shape {data.shape}; Variable has shape {self.shape}" # type: ignore[attr-defined]
)
self._data = data

@@ -996,7 +1006,7 @@ def reset_encoding(self: T_Variable) -> T_Variable:
return self._replace(encoding={})

def copy(
self: T_Variable, deep: bool = True, data: ArrayLike | None = None
self: T_Variable, deep: bool = True, data: T_DuckArray | ArrayLike | None = None
) -> T_Variable:
"""Returns a copy of this object.

@@ -1058,24 +1068,26 @@ def copy(
def _copy(
self: T_Variable,
deep: bool = True,
data: ArrayLike | None = None,
data: T_DuckArray | ArrayLike | None = None,
memo: dict[int, Any] | None = None,
) -> T_Variable:
if data is None:
ndata = self._data
data_old = self._data

if isinstance(ndata, indexing.MemoryCachedArray):
if isinstance(data_old, indexing.MemoryCachedArray):
# don't share caching between copies
ndata = indexing.MemoryCachedArray(ndata.array)
ndata = indexing.MemoryCachedArray(data_old.array)
else:
ndata = data_old

if deep:
ndata = copy.deepcopy(ndata, memo)

else:
ndata = as_compatible_data(data)
if self.shape != ndata.shape:
if self.shape != ndata.shape: # type: ignore[attr-defined]
raise ValueError(
f"Data shape {ndata.shape} must match shape of object {self.shape}"
f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined]
)

attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
@@ -1248,11 +1260,11 @@ def chunk(
inline_array=inline_array,
)

data = self._data
if chunkmanager.is_chunked_array(data):
data = chunkmanager.rechunk(data, chunks) # type: ignore[arg-type]
data_old = self._data
if chunkmanager.is_chunked_array(data_old):
data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type]
else:
if isinstance(data, indexing.ExplicitlyIndexed):
if isinstance(data_old, indexing.ExplicitlyIndexed):
# Unambiguously handle array storage backends (like NetCDF4 and h5py)
# that can't handle general array indexing. For example, in netCDF4 you
# can do "outer" indexing along two dimensions independent, which works
@@ -1261,20 +1273,22 @@
# Using OuterIndexer is a pragmatic choice: dask does not yet handle
# different indexing types in an explicit way:
# https://github.com/dask/dask/issues/2883
data = indexing.ImplicitToExplicitIndexingAdapter(
data, indexing.OuterIndexer
ndata = indexing.ImplicitToExplicitIndexingAdapter(
data_old, indexing.OuterIndexer
)
else:
ndata = data_old

if utils.is_dict_like(chunks):
chunks = tuple(chunks.get(n, s) for n, s in enumerate(data.shape))
chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape))

data = chunkmanager.from_array(
data,
data_chunked = chunkmanager.from_array(
ndata,
chunks, # type: ignore[arg-type]
**_from_array_kwargs,
)

return self._replace(data=data)
return self._replace(data=data_chunked)

def to_numpy(self) -> np.ndarray:
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
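Taken together, the `Variable.data` getter/setter changes mean the wrapped array keeps its concrete type and the setter still validates the replacement's shape against the existing dims. A quick usage sketch (behaviour as shown in the diff; only numpy is used here, but a dask or pint array would round-trip the same way):

```python
import numpy as np
import xarray as xr

v = xr.Variable(("x",), np.arange(4))

# The getter hands back the underlying array with its type preserved.
print(type(v.data))  # <class 'numpy.ndarray'>

# The setter passes the value through as_compatible_data and checks shapes.
v.data = np.linspace(0.0, 1.0, num=4)

try:
    v.data = np.zeros((2, 2))  # shape mismatch for dims ("x",) of length 4
except ValueError as err:
    print(err)
```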
4 changes: 2 additions & 2 deletions xarray/tests/test_parallelcompat.py
@@ -12,7 +12,7 @@
guess_chunkmanager,
list_chunkmanagers,
)
from xarray.core.types import T_Chunks, T_NormalizedChunks
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks
from xarray.tests import has_dask, requires_dask


@@ -76,7 +76,7 @@ def normalize_chunks(
return normalize_chunks(chunks, shape, limit, dtype, previous_chunks)

def from_array(
self, data: np.ndarray, chunks: T_Chunks, **kwargs
self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs
) -> DummyChunkedArray:
from dask import array as da

12 changes: 7 additions & 5 deletions xarray/tests/test_variable.py
@@ -5,6 +5,7 @@
from copy import copy, deepcopy
from datetime import datetime, timedelta
from textwrap import dedent
from typing import Generic

import numpy as np
import pandas as pd
@@ -26,6 +27,7 @@
VectorizedIndexer,
)
from xarray.core.pycompat import array_type
from xarray.core.types import T_DuckArray
from xarray.core.utils import NDArrayMixin
from xarray.core.variable import as_compatible_data, as_variable
from xarray.tests import (
@@ -2529,7 +2531,7 @@ def test_to_index_variable_copy(self) -> None:
assert a.dims == ("x",)


class TestAsCompatibleData:
class TestAsCompatibleData(Generic[T_DuckArray]):
def test_unchanged_types(self):
types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray)
for t in types:
@@ -2610,17 +2612,17 @@ def test_tz_datetime(self) -> None:
times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz))
with warnings.catch_warnings():
warnings.simplefilter("ignore")
actual = as_compatible_data(times_s)
actual: T_DuckArray = as_compatible_data(times_s)
assert actual.array == times_s
assert actual.array.dtype == pd.DatetimeTZDtype("ns", tz)

series = pd.Series(times_s)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
actual = as_compatible_data(series)
actual2: T_DuckArray = as_compatible_data(series)

np.testing.assert_array_equal(actual, series.values)
assert actual.dtype == np.dtype("datetime64[ns]")
np.testing.assert_array_equal(actual2, series.values)
assert actual2.dtype == np.dtype("datetime64[ns]")

def test_full_like(self) -> None:
# For more thorough tests, see test_variable.py