Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getitem to array protocol #8406

Merged
merged 12 commits into from
Dec 12, 2023
40 changes: 37 additions & 3 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Default(Enum):
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)


_dtype = np.dtype
_DType = TypeVar("_DType", bound=np.dtype[Any])
_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any])
# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic`
Expand Down Expand Up @@ -69,9 +69,16 @@ def dtype(self) -> _DType_co:
_Dims = tuple[_Dim, ...]

_DimsLike = Union[str, Iterable[_Dim]]
_AttrsLike = Union[Mapping[Any, Any], None]

_dtype = np.dtype
# https://data-apis.org/array-api/latest/API_specification/indexing.html
# TODO: np.array_api was bugged and didn't allow (None,), but should!
# https://github.com/numpy/numpy/pull/25022
# https://github.com/data-apis/array-api/pull/674
_IndexKey = Union[int, slice, "ellipsis"]
_IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...]
_IndexKeyLike = Union[_IndexKey, _IndexKeys]

_AttrsLike = Union[Mapping[Any, Any], None]


class _SupportsReal(Protocol[_T_co]):
Expand Down Expand Up @@ -113,6 +120,25 @@ class _arrayfunction(
Corresponds to np.ndarray.
"""

@overload
def __getitem__(
self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], /
) -> _arrayfunction[Any, _DType_co]:
...

@overload
def __getitem__(self, key: _IndexKeyLike, /) -> Any:
...

def __getitem__(
self,
key: _IndexKeyLike
| _arrayfunction[Any, Any]
| tuple[_arrayfunction[Any, Any], ...],
/,
) -> _arrayfunction[Any, _DType_co] | Any:
...

@overload
def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]:
...
Expand Down Expand Up @@ -165,6 +191,14 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType
Corresponds to np.ndarray.
"""

def __getitem__(
self,
key: _IndexKeyLike
| Any, # TODO: Any should be _arrayapi[Any, _dtype[np.integer]]
/,
) -> _arrayapi[Any, Any]:
...

def __array_namespace__(self) -> ModuleType:
...

Expand Down
14 changes: 14 additions & 0 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_AttrsLike,
_DimsLike,
_DType,
_IndexKeyLike,
_Shape,
duckarray,
)
Expand Down Expand Up @@ -58,6 +59,19 @@ class CustomArrayIndexable(
ExplicitlyIndexed,
Generic[_ShapeType_co, _DType_co],
):
def __getitem__(
self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], /
) -> CustomArrayIndexable[Any, _DType_co]:
if isinstance(key, CustomArrayIndexable):
if isinstance(key.array, type(self.array)):
# TODO: key.array is duckarray here, can it be narrowed down further?
# an _arrayapi cannot be used on a _arrayfunction for example.
return type(self)(array=self.array[key.array]) # type: ignore[index]
else:
raise TypeError("key must have the same array type as self")
else:
return type(self)(array=self.array[key])

def __array_namespace__(self) -> ModuleType:
return np

Expand Down
Loading