Skip to content

Commit

Permalink
TYP: Typing improvements for Index (#59105)
Browse files Browse the repository at this point in the history
* Typing improvements for Index

* better numpy type hints for Index.delete

* replace some hints with literals, move slice_type to _typing.py
  • Loading branch information
AndreyKolomiets authored Jul 21, 2024
1 parent 18a3eec commit 080add1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
2 changes: 2 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,3 +526,5 @@ def closed(self) -> bool:

# maintaine the sub-type of any hashable sequence
SequenceT = TypeVar("SequenceT", bound=Sequence[Hashable])

SliceType = Optional[Hashable]
41 changes: 30 additions & 11 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
ArrayLike,
Axes,
Axis,
AxisInt,
DropKeep,
Dtype,
DtypeObj,
F,
IgnoreRaise,
Expand All @@ -57,6 +59,7 @@
ReindexMethod,
Self,
Shape,
SliceType,
npt,
)
from pandas.compat.numpy import function as nv
Expand Down Expand Up @@ -1087,7 +1090,7 @@ def view(self, cls=None):
result._id = self._id
return result

def astype(self, dtype, copy: bool = True):
def astype(self, dtype: Dtype, copy: bool = True):
"""
Create an Index with values cast to dtypes.
Expand Down Expand Up @@ -2957,7 +2960,7 @@ def _dti_setop_align_tzs(self, other: Index, setop: str_t) -> tuple[Index, Index
return self, other

@final
def union(self, other, sort=None):
def union(self, other, sort: bool | None = None):
"""
Form the union of two Index objects.
Expand Down Expand Up @@ -3334,7 +3337,7 @@ def _intersection_via_get_indexer(
return result

@final
def difference(self, other, sort=None):
def difference(self, other, sort: bool | None = None):
"""
Return a new Index with elements of index not in `other`.
Expand Down Expand Up @@ -3420,7 +3423,12 @@ def _wrap_difference_result(self, other, result):
# We will override for MultiIndex to handle empty results
return self._wrap_setop_result(other, result)

def symmetric_difference(self, other, result_name=None, sort=None):
def symmetric_difference(
self,
other,
result_name: abc.Hashable | None = None,
sort: bool | None = None,
):
"""
Compute the symmetric difference of two Index objects.
Expand Down Expand Up @@ -6389,7 +6397,7 @@ def _transform_index(self, func, *, level=None) -> Index:
items = [func(x) for x in self]
return Index(items, name=self.name, tupleize_cols=False)

def isin(self, values, level=None) -> npt.NDArray[np.bool_]:
def isin(self, values, level: str_t | int | None = None) -> npt.NDArray[np.bool_]:
"""
Return a boolean array where the index values are in `values`.
Expand Down Expand Up @@ -6687,7 +6695,12 @@ def get_slice_bound(self, label, side: Literal["left", "right"]) -> int:
else:
return slc

def slice_locs(self, start=None, end=None, step=None) -> tuple[int, int]:
def slice_locs(
self,
start: SliceType = None,
end: SliceType = None,
step: int | None = None,
) -> tuple[int, int]:
"""
Compute slice locations for input labels.
Expand Down Expand Up @@ -6781,7 +6794,9 @@ def slice_locs(self, start=None, end=None, step=None) -> tuple[int, int]:

return start_slice, end_slice

def delete(self, loc) -> Self:
def delete(
self, loc: int | np.integer | list[int] | npt.NDArray[np.integer]
) -> Self:
"""
Make new Index with passed location(-s) deleted.
Expand Down Expand Up @@ -7227,7 +7242,9 @@ def _maybe_disable_logical_methods(self, opname: str_t) -> None:
raise TypeError(f"cannot perform {opname} with {type(self).__name__}")

@Appender(IndexOpsMixin.argmin.__doc__)
def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
def argmin(
self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs
) -> int:
nv.validate_argmin(args, kwargs)
nv.validate_minmax_axis(axis)

Expand All @@ -7240,7 +7257,9 @@ def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
return super().argmin(skipna=skipna)

@Appender(IndexOpsMixin.argmax.__doc__)
def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
def argmax(
self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs
) -> int:
nv.validate_argmax(args, kwargs)
nv.validate_minmax_axis(axis)

Expand All @@ -7251,7 +7270,7 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
raise ValueError("Encountered all NA values")
return super().argmax(skipna=skipna)

def min(self, axis=None, skipna: bool = True, *args, **kwargs):
def min(self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs):
"""
Return the minimum value of the Index.
Expand Down Expand Up @@ -7314,7 +7333,7 @@ def min(self, axis=None, skipna: bool = True, *args, **kwargs):

return nanops.nanmin(self._values, skipna=skipna)

def max(self, axis=None, skipna: bool = True, *args, **kwargs):
def max(self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs):
"""
Return the maximum value of the Index.
Expand Down

0 comments on commit 080add1

Please sign in to comment.