diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 670a2076eb1..37832daca58 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -29,7 +29,7 @@ class Default(Enum): _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) - +_dtype = np.dtype _DType = TypeVar("_DType", bound=np.dtype[Any]) _DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any]) # A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic` @@ -69,9 +69,16 @@ def dtype(self) -> _DType_co: _Dims = tuple[_Dim, ...] _DimsLike = Union[str, Iterable[_Dim]] -_AttrsLike = Union[Mapping[Any, Any], None] -_dtype = np.dtype +# https://data-apis.org/array-api/latest/API_specification/indexing.html +# TODO: np.array_api was bugged and didn't allow (None,), but should! +# https://github.com/numpy/numpy/pull/25022 +# https://github.com/data-apis/array-api/pull/674 +_IndexKey = Union[int, slice, "ellipsis"] +_IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...] +_IndexKeyLike = Union[_IndexKey, _IndexKeys] + +_AttrsLike = Union[Mapping[Any, Any], None] class _SupportsReal(Protocol[_T_co]): @@ -113,6 +120,25 @@ class _arrayfunction( Corresponds to np.ndarray. """ + @overload + def __getitem__( + self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], / + ) -> _arrayfunction[Any, _DType_co]: + ... + + @overload + def __getitem__(self, key: _IndexKeyLike, /) -> Any: + ... + + def __getitem__( + self, + key: _IndexKeyLike + | _arrayfunction[Any, Any] + | tuple[_arrayfunction[Any, Any], ...], + /, + ) -> _arrayfunction[Any, _DType_co] | Any: + ... + @overload def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]: ... @@ -165,6 +191,14 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType Corresponds to np.ndarray. """ + def __getitem__( + self, + key: _IndexKeyLike + | Any, # TODO: Any should be _arrayapi[Any, _dtype[np.integer]] + /, + ) -> _arrayapi[Any, Any]: + ... + def __array_namespace__(self) -> ModuleType: ... diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index c75b01e9e50..9aedcbc80d4 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -28,6 +28,7 @@ _AttrsLike, _DimsLike, _DType, + _IndexKeyLike, _Shape, duckarray, ) @@ -58,6 +59,19 @@ class CustomArrayIndexable( ExplicitlyIndexed, Generic[_ShapeType_co, _DType_co], ): + def __getitem__( + self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], / + ) -> CustomArrayIndexable[Any, _DType_co]: + if isinstance(key, CustomArrayIndexable): + if isinstance(key.array, type(self.array)): + # TODO: key.array is duckarray here, can it be narrowed down further? + # an _arrayapi cannot be used on a _arrayfunction for example. + return type(self)(array=self.array[key.array]) # type: ignore[index] + else: + raise TypeError("key must have the same array type as self") + else: + return type(self)(array=self.array[key]) + def __array_namespace__(self) -> ModuleType: return np