Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give NamedArray Generic dimension type #8276

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import math
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload

import numpy as np

Expand Down Expand Up @@ -39,8 +39,6 @@
PostPersistCallable: Any # type: ignore[no-redef]

# T_NamedArray = TypeVar("T_NamedArray", bound="NamedArray[T_DuckArray]")
DimsInput = Union[str, Iterable[Hashable]]
Dims = tuple[Hashable, ...]
AttrsInput = Union[Mapping[Any, Any], None]


Expand Down Expand Up @@ -75,7 +73,10 @@ def as_compatible_data(
return cast(T_DuckArray, np.asarray(data))


class NamedArray(Generic[T_DuckArray]):
T_Dim = TypeVar("T_Dim", bound=Hashable)


class NamedArray(Generic[T_Dim, T_DuckArray]):

"""A lightweight wrapper around duck arrays with named dimensions and attributes which describe a single Array.
Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names,
Expand All @@ -84,20 +85,60 @@ class NamedArray(Generic[T_DuckArray]):
__slots__ = ("_data", "_dims", "_attrs")

_data: T_DuckArray
_dims: Dims
_dims: tuple[T_Dim, ...]
_attrs: dict[Any, Any] | None

@overload
def __init__(
self: NamedArray[str, T_DuckArray],
dims: str,
data: T_DuckArray,
attrs: AttrsInput = None,
fastpath: bool = False,
) -> None:
...

@overload
def __init__(
self: NamedArray[str, np.ndarray[Any, np.dtype[np.generic]]],
dims: str,
data: np.typing.ArrayLike,
attrs: AttrsInput = None,
fastpath: bool = False,
) -> None:
...

@overload
def __init__(
self: NamedArray[T_Dim, T_DuckArray],
dims: Iterable[T_Dim],
data: T_DuckArray,
attrs: AttrsInput = None,
fastpath: bool = False,
) -> None:
...

@overload
def __init__(
self: NamedArray[T_Dim, np.ndarray[Any, np.dtype[np.generic]]],
dims: Iterable[T_Dim],
data: np.typing.ArrayLike,
attrs: AttrsInput = None,
fastpath: bool = False,
) -> None:
...

def __init__(
self,
dims: DimsInput,
dims: str | Iterable[T_Dim],
data: T_DuckArray | np.typing.ArrayLike,
attrs: AttrsInput = None,
fastpath: bool = False,
):
) -> None:
"""
Parameters
----------
dims : str or iterable of str
dims : str or iterable of hashable
Name(s) of the dimension(s).
data : T_DuckArray or np.typing.ArrayLike
The actual data that populates the array. Should match the shape specified by `dims`.
Expand Down Expand Up @@ -194,22 +235,22 @@ def nbytes(self) -> int:
return self.size * self.dtype.itemsize

@property
def dims(self) -> Dims:
def dims(self) -> tuple[T_Dim, ...]:
"""Tuple of dimension names with which this NamedArray is associated."""
return self._dims

@dims.setter
def dims(self, value: DimsInput) -> None:
def dims(self, value: str | Iterable[T_Dim]) -> None:
self._dims = self._parse_dimensions(value)

def _parse_dimensions(self, dims: DimsInput) -> Dims:
dims = (dims,) if isinstance(dims, str) else tuple(dims)
if len(dims) != self.ndim:
def _parse_dimensions(self, dims: str | Iterable[T_Dim]) -> tuple[T_Dim, ...]:
pdims = (dims,) if isinstance(dims, str) else tuple(dims)
if len(pdims) != self.ndim:
raise ValueError(
f"dimensions {dims} must have the same length as the "
f"dimensions {pdims} must have the same length as the "
f"number of data dimensions, ndim={self.ndim}"
)
return dims
return pdims # type: ignore[return-value]

@property
def attrs(self) -> dict[Any, Any]:
Expand Down Expand Up @@ -397,7 +438,7 @@ def sizes(self) -> dict[Hashable, int]:

def _replace(
self,
dims: DimsInput | Default = _default,
dims: str | Iterable[T_Dim] | Default = _default,
data: T_DuckArray | np.typing.ArrayLike | Default = _default,
attrs: AttrsInput | Default = _default,
) -> Self:
Expand Down
16 changes: 8 additions & 8 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class CustomArrayIndexable(CustomArrayBase, xr.core.indexing.ExplicitlyIndexed):

def test_properties() -> None:
data = 0.5 * np.arange(10).reshape(2, 5)
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray(["x", "y"], data, {"key": "value"})
assert named_array.dims == ("x", "y")
assert np.array_equal(named_array.data, data)
Expand All @@ -104,7 +104,7 @@ def test_properties() -> None:


def test_attrs() -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5))
assert named_array.attrs == {}
named_array.attrs["key"] = "value"
Expand All @@ -114,7 +114,7 @@ def test_attrs() -> None:


def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray(["x", "y", "z"], random_inputs)
assert np.array_equal(named_array.data, random_inputs)
with pytest.raises(ValueError):
Expand All @@ -130,7 +130,7 @@ def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
],
)
def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray([], data)
assert named_array.data == data
assert named_array.dims == ()
Expand All @@ -142,7 +142,7 @@ def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None:


def test_0d_object() -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray([], (10, 12, 12))
expected_data = np.empty((), dtype=object)
expected_data[()] = (10, 12, 12)
Expand All @@ -157,7 +157,7 @@ def test_0d_object() -> None:


def test_0d_datetime() -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray([], np.datetime64("2000-01-01"))
assert named_array.dtype == np.dtype("datetime64[D]")

Expand All @@ -179,7 +179,7 @@ def test_0d_datetime() -> None:
def test_0d_timedelta(
timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64]
) -> None:
named_array: NamedArray[np.ndarray[Any, np.dtype[np.timedelta64]]]
named_array: NamedArray[str, np.ndarray[Any, np.dtype[np.timedelta64]]]
named_array = NamedArray([], timedelta)
assert named_array.dtype == expected_dtype
assert named_array.data == timedelta
Expand All @@ -196,7 +196,7 @@ def test_0d_timedelta(
],
)
def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None:
named_array: NamedArray[np.ndarray[Any, Any]]
named_array: NamedArray[str, np.ndarray[Any, Any]]
named_array = NamedArray(dims, np.random.random(data_shape))
assert named_array.dims == tuple(dims)
if raises:
Expand Down
Loading