Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: interval.pyi #44922

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions pandas/_libs/interval.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from datetime import timedelta
import numbers
from typing import (
Any,
Generic,
TypeVar,
overload,
)

import numpy as np
import numpy.typing as npt

from pandas._typing import IntervalBound

from pandas import (
Timedelta,
Timestamp,
)

_OrderableMixinT = TypeVar(
"_OrderableMixinT", int, float, Timestamp, Timedelta, npt.NDArray[np.generic]
)
_OrderableT = TypeVar("_OrderableT", int, float, Timestamp, Timedelta)

# note: mypy doesn't support overloading properties
# based on github.com/microsoft/python-type-stubs/pull/167
class _LengthProperty:
@overload
def __get__(self, instance: IntervalMixin[Timestamp], owner: Any) -> Timedelta: ...
@overload
def __get__(
self, instance: IntervalMixin[_OrderableMixinT], owner: Any
) -> _OrderableMixinT: ...

class IntervalMixin(Generic[_OrderableMixinT]):
@property
def closed_left(self) -> bool: ...
@property
def closed_right(self) -> bool: ...
@property
def open_left(self) -> bool: ...
@property
def open_right(self) -> bool: ...
@property
def mid(self) -> _OrderableT: ...
length: _LengthProperty
@property
def is_empty(self) -> bool: ...
def _check_closed_matches(self, other: IntervalMixin, name: str = ...) -> None: ...

class Interval(IntervalMixin[_OrderableT]):
def __init__(
self,
Dr-Irv marked this conversation as resolved.
Show resolved Hide resolved
left: _OrderableT,
right: _OrderableT,
closed: IntervalBound = ...,
) -> None: ...
@property
def closed(self) -> str: ...
@property
def left(self) -> _OrderableT: ...
@property
def right(self) -> _OrderableT: ...
def __str__(self) -> str: ...
# TODO: could return Interval with different type
Dr-Irv marked this conversation as resolved.
Show resolved Hide resolved
def __add__(
self, y: numbers.Number | np.timedelta64 | timedelta
Copy link
Contributor

@Dr-Irv Dr-Irv Feb 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to make the operators type specific. For Interval[Timestamp], you can only add and subtract Timedelta. For the numeric ones, you can only add/subtract/multiply/divide float or int. See the latest in microsoft/python-type-stubs#167

) -> Interval[_OrderableT]: ...
def __radd__(
self, y: numbers.Number | np.timedelta64 | timedelta
) -> Interval[_OrderableT]: ...
def __sub__(
self, y: numbers.Number | np.timedelta64 | timedelta
) -> Interval[_OrderableT]: ...
def __mul__(self, y: numbers.Number) -> Interval[_OrderableT]: ...
Copy link
Contributor

@Dr-Irv Dr-Irv Feb 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

multiply and divide don't apply for the Timestamp intervals

def __rmul__(self, y: numbers.Number) -> Interval[_OrderableT]: ...
def __truediv__(self, y: numbers.Number) -> Interval[_OrderableT]: ...
def __floordiv__(self, y: numbers.Number) -> Interval[_OrderableT]: ...
def __hash__(self) -> int: ...
def overlaps(self, other: Interval[_OrderableT]) -> bool: ...

VALID_CLOSED: frozenset[str]

# takes npt.NDArray[Interval[_OrderableT]] and returns arrays of type
# _OrderableT but _Orderable is not a valid dtype
def intervals_to_interval_bounds(
intervals: npt.NDArray[np.object_], validate_closed: bool = ...
) -> tuple[np.ndarray, np.ndarray, str]: ...

# from pandas/_libs/intervaltree.pxi.in
_GenericT = TypeVar("_GenericT", bound=np.generic)

# error: Value of type variable "_OrderableMixinT" of "IntervalMixin"
# cannot be "ndarray"
class IntervalTree(
Generic[_GenericT],
IntervalMixin[npt.NDArray[_GenericT]], # type: ignore[type-var]
):
_na_count: int
def __init__(
self,
left: npt.NDArray[_GenericT],
right: npt.NDArray[_GenericT],
closed: IntervalBound = ...,
leaf_size: int = ...,
) -> None: ...
@property
def left_sorter(self) -> npt.NDArray[_GenericT]: ...
@property
def right_sorter(self) -> npt.NDArray[_GenericT]: ...
@property
def is_overlapping(self) -> bool: ...
@property
def is_monotonic_increasing(self) -> bool: ...
def get_indexer(self, target: np.ndarray) -> npt.NDArray[np.intp]: ...
def get_indexer_non_unique(
self, target: np.ndarray
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ...
def __repr__(self) -> str: ...
def clear_mapping(self) -> None: ...
6 changes: 4 additions & 2 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@
PythonScalar = Union[str, int, float, bool]
DatetimeLikeScalar = Union["Period", "Timestamp", "Timedelta"]
PandasScalar = Union["Period", "Timestamp", "Timedelta", "Interval"]
Scalar = Union[PythonScalar, PandasScalar]
Scalar = Union[PythonScalar, PandasScalar, np.datetime64, np.timedelta64]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this PR, Interval resolved to Any. Some functions that are supposed to return a Scalar can return np.datetime64 / np.timedelta64 which was previously unnoticed.

IntStrT = TypeVar("IntStrT", int, str)


# timestamp and timedelta convertible types

TimestampConvertibleTypes = Union[
Expand Down Expand Up @@ -304,3 +303,6 @@ def closed(self) -> bool:

# read_xml parsers
XMLParsers = Literal["lxml", "etree"]

# on which side(s) Interval is closed
IntervalBound = Literal["left", "right", "both", "neither"]
2 changes: 1 addition & 1 deletion pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def factorize(
else:
dtype = values.dtype
values = _ensure_data(values)
na_value: Scalar
na_value: Scalar | None
twoertwein marked this conversation as resolved.
Show resolved Hide resolved

if original.dtype.kind in ["m", "M"]:
# Note: factorize_array will cast NaT bc it has a __int__
Expand Down
28 changes: 17 additions & 11 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pandas._typing import (
ArrayLike,
Dtype,
IntervalBound,
NpDtype,
PositionalIndexer,
ScalarIndexer,
Expand Down Expand Up @@ -196,6 +197,9 @@ class IntervalArray(IntervalMixin, ExtensionArray):
ndim = 1
can_hold_na = True
_na_value = _fill_value = np.nan
_left: np.ndarray
_right: np.ndarray
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about the type of _left and _right, see mypy error for _from_sequence in this file.

_dtype: IntervalDtype

# ---------------------------------------------------------------------
# Constructors
Expand Down Expand Up @@ -657,11 +661,7 @@ def __getitem__(
if is_scalar(left) and isna(left):
return self._fill_value
return Interval(left, right, self.closed)
# error: Argument 1 to "ndim" has incompatible type "Union[ndarray,
# ExtensionArray]"; expected "Union[Union[int, float, complex, str, bytes,
# generic], Sequence[Union[int, float, complex, str, bytes, generic]],
# Sequence[Sequence[Any]], _SupportsArray]"
if np.ndim(left) > 1: # type: ignore[arg-type]
if np.ndim(left) > 1:
# GH#30588 multi-dimensional indexer disallowed
raise ValueError("multi-dimensional indexing not allowed")
return self._shallow_copy(left, right)
Expand Down Expand Up @@ -945,10 +945,10 @@ def _concat_same_type(
-------
IntervalArray
"""
closed = {interval.closed for interval in to_concat}
if len(closed) != 1:
closed_set = {interval.closed for interval in to_concat}
if len(closed_set) != 1:
raise ValueError("Intervals must all be closed on the same side.")
closed = closed.pop()
closed = closed_set.pop()

left = np.concatenate([interval.left for interval in to_concat])
right = np.concatenate([interval.right for interval in to_concat])
Expand Down Expand Up @@ -1317,7 +1317,7 @@ def overlaps(self, other):
# ---------------------------------------------------------------------

@property
def closed(self):
def closed(self) -> IntervalBound:
"""
Whether the intervals are closed on the left-side, right-side, both or
neither.
Expand Down Expand Up @@ -1665,8 +1665,14 @@ def _from_combined(self, combined: np.ndarray) -> IntervalArray:

dtype = self._left.dtype
if needs_i8_conversion(dtype):
new_left = type(self._left)._from_sequence(nc[:, 0], dtype=dtype)
new_right = type(self._right)._from_sequence(nc[:, 1], dtype=dtype)
# error: "Type[ndarray[Any, Any]]" has no attribute "_from_sequence"
new_left = type(self._left)._from_sequence( # type: ignore[attr-defined]
nc[:, 0], dtype=dtype
)
# error: "Type[ndarray[Any, Any]]" has no attribute "_from_sequence"
new_right = type(self._right)._from_sequence( # type: ignore[attr-defined]
nc[:, 1], dtype=dtype
)
else:
new_left = nc[:, 0].view(dtype)
new_right = nc[:, 1].view(dtype)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def to_numpy(
self,
dtype: npt.DTypeLike | None = None,
copy: bool = False,
na_value: Scalar = lib.no_default,
na_value: Scalar | libmissing.NAType | lib.NoDefault = lib.no_default,
) -> np.ndarray:
"""
Convert to a NumPy Array.
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ def _str_replace(
return type(self)(result)

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar = None
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if pa_version_under4p0:
return super()._str_match(pat, case, flags, na)
Expand All @@ -771,7 +771,9 @@ def _str_match(
pat = "^" + pat
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(self, pat, case: bool = True, flags: int = 0, na: Scalar = None):
def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if pa_version_under4p0:
return super()._str_fullmatch(pat, case, flags, na)

Expand Down
3 changes: 1 addition & 2 deletions pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
ArrayLike,
NpDtype,
RandomState,
Scalar,
T,
)
from pandas.util._exceptions import find_stack_level
Expand Down Expand Up @@ -517,7 +516,7 @@ def f(x):


def convert_to_list_like(
values: Scalar | Iterable | AnyArrayLike,
values: Hashable | Iterable | AnyArrayLike,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All scalars should also be hashable.

) -> list | AnyArrayLike:
"""
Convert list-like or scalar input to list-like. List, numpy and pandas array-like
Expand Down
41 changes: 31 additions & 10 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import (
Any,
Hashable,
Literal,
)

import numpy as np
Expand All @@ -28,6 +29,7 @@
from pandas._typing import (
Dtype,
DtypeObj,
IntervalBound,
npt,
)
from pandas.errors import InvalidIndexError
Expand Down Expand Up @@ -191,10 +193,12 @@ class IntervalIndex(ExtensionIndex):
_typ = "intervalindex"

# annotate properties pinned via inherit_names
closed: str
closed: IntervalBound
is_non_overlapping_monotonic: bool
closed_left: bool
closed_right: bool
open_left: bool
open_right: bool

_data: IntervalArray
_values: IntervalArray
Expand Down Expand Up @@ -246,7 +250,7 @@ def __new__(
def from_breaks(
cls,
breaks,
closed: str = "right",
closed: IntervalBound = "right",
name: Hashable = None,
copy: bool = False,
dtype: Dtype | None = None,
Expand Down Expand Up @@ -277,7 +281,7 @@ def from_arrays(
cls,
left,
right,
closed: str = "right",
closed: IntervalBound = "right",
name: Hashable = None,
copy: bool = False,
dtype: Dtype | None = None,
Expand Down Expand Up @@ -307,7 +311,7 @@ def from_arrays(
def from_tuples(
cls,
data,
closed: str = "right",
closed: IntervalBound = "right",
name: Hashable = None,
copy: bool = False,
dtype: Dtype | None = None,
Expand All @@ -318,8 +322,10 @@ def from_tuples(

# --------------------------------------------------------------------

# error: Return type "IntervalTree[Any]" of "_engine" incompatible with return type
# "IndexEngine" in supertype "Index"
@cache_readonly
def _engine(self) -> IntervalTree:
def _engine(self) -> IntervalTree: # type: ignore[override]
left = self._maybe_convert_i8(self.left)
right = self._maybe_convert_i8(self.right)
return IntervalTree(left, right, closed=self.closed)
Expand Down Expand Up @@ -511,7 +517,10 @@ def _maybe_convert_i8(self, key):
left = self._maybe_convert_i8(key.left)
right = self._maybe_convert_i8(key.right)
constructor = Interval if scalar else IntervalIndex.from_arrays
return constructor(left, right, closed=self.closed)
# error: "object" not callable
return constructor( # type: ignore[operator]
left, right, closed=self.closed
)

if scalar:
# Timestamp/Timedelta
Expand Down Expand Up @@ -543,7 +552,7 @@ def _maybe_convert_i8(self, key):

return key_i8

def _searchsorted_monotonic(self, label, side: str = "left"):
def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
if not self.is_non_overlapping_monotonic:
raise KeyError(
"can only get slices from an IntervalIndex if bounds are "
Expand Down Expand Up @@ -663,7 +672,9 @@ def _get_indexer(
# homogeneous scalar index: use IntervalTree
# we should always have self._should_partial_index(target) here
target = self._maybe_convert_i8(target)
indexer = self._engine.get_indexer(target.values)
# error: Argument 1 to "get_indexer" of "IntervalTree" has incompatible type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not add typing to _maybe_convert_i8 ?

# "Union[ExtensionArray, ndarray[Any, Any]]"; expected "ndarray[Any, Any]"
indexer = self._engine.get_indexer(target.values) # type: ignore[arg-type]
else:
# heterogeneous scalar index: defer elementwise to get_loc
# we should always have self._should_partial_index(target) here
Expand Down Expand Up @@ -698,7 +709,12 @@ def get_indexer_non_unique(
# Note: this case behaves differently from other Index subclasses
# because IntervalIndex does partial-int indexing
target = self._maybe_convert_i8(target)
indexer, missing = self._engine.get_indexer_non_unique(target.values)
# error: Argument 1 to "get_indexer_non_unique" of "IntervalTree" has
# incompatible type "Union[ExtensionArray, ndarray[Any, Any]]"; expected
# "ndarray[Any, Any]" [arg-type]
indexer, missing = self._engine.get_indexer_non_unique(
target.values # type: ignore[arg-type]
)

return ensure_platform_int(indexer), ensure_platform_int(missing)

Expand Down Expand Up @@ -941,7 +957,12 @@ def _is_type_compatible(a, b) -> bool:


def interval_range(
start=None, end=None, periods=None, freq=None, name: Hashable = None, closed="right"
start=None,
end=None,
periods=None,
freq=None,
name: Hashable = None,
closed: IntervalBound = "right",
) -> IntervalIndex:
"""
Return a fixed frequency IntervalIndex.
Expand Down
Loading