Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert indexes.py to use Self for typing #8217

Merged
merged 9 commits into from
Sep 20, 2023
63 changes: 32 additions & 31 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)

if TYPE_CHECKING:
from xarray.core.types import ErrorOptions, JoinOptions, T_Index
from xarray.core.types import ErrorOptions, JoinOptions, Self
from xarray.core.variable import Variable


Expand Down Expand Up @@ -60,11 +60,11 @@ class Index:

@classmethod
def from_variables(
cls: type[T_Index],
cls,
variables: Mapping[Any, Variable],
*,
options: Mapping[str, Any],
) -> T_Index:
) -> Self:
"""Create a new index object from one or more coordinate variables.

This factory method must be implemented in all subclasses of Index.
Expand All @@ -88,11 +88,11 @@ def from_variables(

@classmethod
def concat(
cls: type[T_Index],
indexes: Sequence[T_Index],
cls,
indexes: Sequence[Self],
dim: Hashable,
positions: Iterable[Iterable[int]] | None = None,
) -> T_Index:
) -> Self:
"""Create a new index by concatenating one or more indexes of the same
type.

Expand Down Expand Up @@ -120,9 +120,7 @@ def concat(
raise NotImplementedError()

@classmethod
def stack(
cls: type[T_Index], variables: Mapping[Any, Variable], dim: Hashable
) -> T_Index:
def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Self:
"""Create a new index by stacking coordinate variables into a single new
dimension.

Expand Down Expand Up @@ -208,8 +206,8 @@ def to_pandas_index(self) -> pd.Index:
raise TypeError(f"{self!r} cannot be cast to a pandas.Index object")

def isel(
self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
) -> T_Index | None:
self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
) -> Self | None:
"""Maybe returns a new index from the current index itself indexed by
positional indexers.

Expand Down Expand Up @@ -264,7 +262,7 @@ def sel(self, labels: dict[Any, Any]) -> IndexSelResult:
"""
raise NotImplementedError(f"{self!r} doesn't support label-based selection")

def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index:
def join(self, other: Self, how: JoinOptions = "inner") -> Self:
"""Return a new index from the combination of this index with another
index of the same type.

Expand All @@ -286,7 +284,7 @@ def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index:
f"{self!r} doesn't support alignment with inner/outer join method"
)

def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]:
def reindex_like(self, other: Self) -> dict[Hashable, Any]:
"""Query the index with another index of the same type.

Implementation is optional but required in order to support alignment.
Expand All @@ -304,7 +302,7 @@ def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]:
"""
raise NotImplementedError(f"{self!r} doesn't support re-indexing labels")

def equals(self: T_Index, other: T_Index) -> bool:
def equals(self, other: Self) -> bool:
"""Compare this index with another index of the same type.

Implementation is optional but required in order to support alignment.
Expand All @@ -321,7 +319,7 @@ def equals(self: T_Index, other: T_Index) -> bool:
"""
raise NotImplementedError()

def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None:
def roll(self, shifts: Mapping[Any, int]) -> Self | None:
"""Roll this index by an offset along one or more dimensions.

This method can be re-implemented in subclasses of Index, e.g., when the
Expand All @@ -347,10 +345,10 @@ def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None:
return None

def rename(
self: T_Index,
self,
name_dict: Mapping[Any, Hashable],
dims_dict: Mapping[Any, Hashable],
) -> T_Index:
) -> Self:
"""Maybe update the index with new coordinate and dimension names.

This method should be re-implemented in subclasses of Index if it has
Expand All @@ -377,7 +375,7 @@ def rename(
"""
return self

def copy(self: T_Index, deep: bool = True) -> T_Index:
def copy(self, deep: bool = True) -> Self:
"""Return a (deep) copy of this index.

Implementation in subclasses of Index is optional. The base class
Expand All @@ -396,15 +394,13 @@ def copy(self: T_Index, deep: bool = True) -> T_Index:
"""
return self._copy(deep=deep)

def __copy__(self: T_Index) -> T_Index:
def __copy__(self) -> Self:
return self.copy(deep=False)

def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index:
return self._copy(deep=True, memo=memo)

def _copy(
self: T_Index, deep: bool = True, memo: dict[int, Any] | None = None
) -> T_Index:
def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Self:
cls = self.__class__
copied = cls.__new__(cls)
if deep:
Expand All @@ -414,7 +410,7 @@ def _copy(
copied.__dict__.update(self.__dict__)
return copied

def __getitem__(self: T_Index, indexer: Any) -> T_Index:
def __getitem__(self, indexer: Any) -> Self:
raise NotImplementedError()

def _repr_inline_(self, max_width):
Expand Down Expand Up @@ -674,10 +670,10 @@ def _concat_indexes(indexes, dim, positions=None) -> pd.Index:
@classmethod
def concat(
cls,
indexes: Sequence[PandasIndex],
indexes: Sequence[Self],
dim: Hashable,
positions: Iterable[Iterable[int]] | None = None,
) -> PandasIndex:
) -> Self:
new_pd_index = cls._concat_indexes(indexes, dim, positions)

if not indexes:
Expand Down Expand Up @@ -800,7 +796,11 @@ def equals(self, other: Index):
return False
return self.index.equals(other.index) and self.dim == other.dim

def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasIndex:
def join(
self,
other: Self,
how: str = "inner",
) -> Self:
if how == "outer":
index = self.index.union(other.index)
else:
Expand All @@ -811,7 +811,7 @@ def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasInd
return type(self)(index, self.dim, coord_dtype=coord_dtype)

def reindex_like(
self, other: PandasIndex, method=None, tolerance=None
self, other: Self, method=None, tolerance=None
) -> dict[Hashable, Any]:
if not self.index.is_unique:
raise ValueError(
Expand Down Expand Up @@ -963,12 +963,12 @@ def from_variables(
return obj

@classmethod
def concat( # type: ignore[override]
def concat(
cls,
indexes: Sequence[PandasMultiIndex],
indexes: Sequence[Self],
dim: Hashable,
positions: Iterable[Iterable[int]] | None = None,
) -> PandasMultiIndex:
) -> Self:
new_pd_index = cls._concat_indexes(indexes, dim, positions)

if not indexes:
Expand Down Expand Up @@ -1602,7 +1602,7 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]:
return Indexes(indexes, self._variables, index_type=pd.Index)

def copy_indexes(
self, deep: bool = True, memo: dict[int, Any] | None = None
self, deep: bool = True, memo: dict[int, T_PandasOrXarrayIndex] | None = None
) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]:
"""Return a new dictionary with copies of indexes, preserving
unique indexes.
Expand All @@ -1619,6 +1619,7 @@ def copy_indexes(
new_indexes = {}
new_index_vars = {}

idx: T_PandasOrXarrayIndex
for idx, coords in self.group_by_index():
if isinstance(idx, pd.Index):
convert_new_idx = True
Expand Down
Loading