Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: Type hints & assert statements #42044

Closed
wants to merge 9 commits into from
4 changes: 3 additions & 1 deletion pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Dict,
Hashable,
List,
Literal,
Mapping,
Optional,
Sequence,
Expand All @@ -37,7 +38,6 @@
# https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
if TYPE_CHECKING:
from typing import (
Literal,
TypedDict,
final,
)
Expand Down Expand Up @@ -123,6 +123,8 @@
Frequency = Union[str, "DateOffset"]
Axes = Collection[Any]
RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState]
MergeTypes = Literal["inner", "outer", "left", "right", "cross"]
ConcatTypes = Literal["inner", "outer"]

# dtypes
NpDtype = Union[str, np.dtype]
Expand Down
11 changes: 9 additions & 2 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Hashable,
Iterator,
List,
Literal,
cast,
)
import warnings
Expand Down Expand Up @@ -518,7 +519,10 @@ def apply_multiple(self) -> FrameOrSeriesUnion:
return self.obj.aggregate(self.f, self.axis, *self.args, **self.kwargs)

def normalize_dictlike_arg(
self, how: str, obj: FrameOrSeriesUnion, func: AggFuncTypeDict
self,
how: Literal["apply", "agg", "transform"],
obj: FrameOrSeriesUnion,
func: AggFuncTypeDict,
) -> AggFuncTypeDict:
"""
Handler for dict-like argument.
Expand All @@ -527,7 +531,10 @@ def normalize_dictlike_arg(
that a nested renamer is not passed. Also normalizes to all lists
when values consists of a mix of list and non-lists.
"""
assert how in ("apply", "agg", "transform")
if how not in ("apply", "agg", "transform"):
raise ValueError(
"Value for how argument must be one of : apply, agg, transform"
)

# Can't use func.values(); wouldn't work for a Series
if (
Expand Down
7 changes: 5 additions & 2 deletions pandas/core/arrays/_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"""
from __future__ import annotations

from typing import Literal

import numpy as np

from pandas._libs.lib import i8max
Expand Down Expand Up @@ -75,7 +77,7 @@ def generate_regular_range(


def _generate_range_overflow_safe(
endpoint: int, periods: int, stride: int, side: str = "start"
endpoint: int, periods: int, stride: int, side: Literal["start", "end"] = "start"
) -> int:
"""
Calculate the second endpoint for passing to np.arange, checking
Expand Down Expand Up @@ -142,13 +144,14 @@ def _generate_range_overflow_safe(


def _generate_range_overflow_safe_signed(
endpoint: int, periods: int, stride: int, side: str
endpoint: int, periods: int, stride: int, side: Literal["start", "end"]
) -> int:
"""
A special case for _generate_range_overflow_safe where `periods * stride`
can be calculated without overflowing int64 bounds.
"""
assert side in ["start", "end"]

if side == "end":
stride *= -1

Expand Down
26 changes: 15 additions & 11 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
to_offset,
tzconversion,
)
from pandas._typing import Dtype
from pandas.errors import PerformanceWarning

from pandas.core.dtypes.cast import astype_dt64_to_dt64tz
Expand Down Expand Up @@ -1967,12 +1968,12 @@ def sequence_to_datetimes(

def sequence_to_dt64ns(
data,
dtype=None,
copy=False,
tz=None,
dayfirst=False,
yearfirst=False,
ambiguous="raise",
dtype: Dtype | None = None,
copy: bool = False,
tz: tzinfo | str = None,
dayfirst: bool = False,
yearfirst: bool = False,
ambiguous: str | bool = "raise",
*,
allow_object: bool = False,
allow_mixed: bool = False,
Expand Down Expand Up @@ -2126,10 +2127,10 @@ def sequence_to_dt64ns(

def objects_to_datetime64ns(
data: np.ndarray,
dayfirst,
yearfirst,
utc=False,
errors="raise",
dayfirst: bool,
yearfirst: bool,
utc: bool = False,
errors: Literal["raise", "coerce", "ignore"] = "raise",
require_iso8601: bool = False,
allow_object: bool = False,
allow_mixed: bool = False,
Expand Down Expand Up @@ -2164,7 +2165,10 @@ def objects_to_datetime64ns(
------
ValueError : if data cannot be converted to datetimes
"""
assert errors in ["raise", "ignore", "coerce"]
if errors not in ["raise", "ignore", "coerce"]:
raise ValueError(
"Value for errors argument must be one of: raise, coerce, ignore"
)

# if str-dtype, convert
data = np.array(data, copy=False, dtype=np.object_)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9191,7 +9191,7 @@ def merge(
sort: bool = False,
suffixes: Suffixes = ("_x", "_y"),
copy: bool = True,
indicator: bool = False,
indicator: bool | str = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure about this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per _MergeOperation you can optionally supply a string as the column name, otherwise the indicator is given a default name:

if isinstance(self.indicator, str):
self.indicator_name = self.indicator
elif isinstance(self.indicator, bool):
self.indicator_name = "_merge" if self.indicator else None

Maybe I missed this, but is there a "valid column name" definition that would be more specific than str?

validate: str | None = None,
) -> DataFrame:
from pandas.core.reshape.merge import merge
Expand Down
9 changes: 7 additions & 2 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5691,7 +5691,9 @@ def _validate_indexer(self, form: str_t, key, kind: str_t):
if key is not None and not is_integer(key):
raise self._invalid_indexer(form, key)

def _maybe_cast_slice_bound(self, label, side: str_t, kind=no_default):
def _maybe_cast_slice_bound(
self, label, side: str_t, kind: Literal["loc", "getitem"] = no_default
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think needs to be amended to include no_default?

):
"""
This function should be overloaded in subclasses that allow non-trivial
casting on label-slice bounds, e.g. datetime-like indices allowing
Expand Down Expand Up @@ -5755,7 +5757,10 @@ def get_slice_bound(self, label, side: str_t, kind=None) -> int:
int
Index of label.
"""
assert kind in ["loc", "getitem", None]
if kind not in ["loc", "getitem", None]:
raise ValueError(
"Value for kind argument must be one of: loc, getitem or None"
)

if side not in ("left", "right"):
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,8 @@ def where(self, other, cond, errors="raise") -> list[Block]:
assert cond.ndim == self.ndim
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))

assert errors in ["raise", "ignore"]
if errors not in ["raise", "ignore"]:
raise ValueError("Value for errors argument must be one of: raise, ignore")
transpose = self.ndim == 2

values = self.values
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import (
TYPE_CHECKING,
Any,
Literal,
cast,
)

Expand Down Expand Up @@ -164,7 +165,7 @@ def clean_interp_method(method: str, index: Index, **kwargs) -> str:
return method


def find_valid_index(values, *, how: str) -> int | None:
def find_valid_index(values, *, how: Literal["first", "last"]) -> int | None:
"""
Retrieves the index of the first valid value.

Expand All @@ -178,7 +179,8 @@ def find_valid_index(values, *, how: str) -> int | None:
-------
int or None
"""
assert how in ["first", "last"]
if how not in ["first", "last"]:
raise ValueError("Value for how argument must be one of : first, last")

if len(values) == 0: # early stop
return None
Expand Down
16 changes: 9 additions & 7 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
DtypeObj,
FrameOrSeries,
IndexLabel,
MergeTypes,
Suffixes,
TimedeltaConvertibleTypes,
)
from pandas.errors import MergeError
from pandas.util._decorators import (
Expand Down Expand Up @@ -92,7 +94,7 @@
def merge(
left: DataFrame | Series,
right: DataFrame | Series,
how: str = "inner",
how: MergeTypes = "inner",
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
Expand All @@ -101,7 +103,7 @@ def merge(
sort: bool = False,
suffixes: Suffixes = ("_x", "_y"),
copy: bool = True,
indicator: bool = False,
indicator: bool | str = False,
validate: str | None = None,
) -> DataFrame:
op = _MergeOperation(
Expand Down Expand Up @@ -331,11 +333,11 @@ def merge_asof(
right_on: IndexLabel | None = None,
left_index: bool = False,
right_index: bool = False,
by=None,
left_by=None,
right_by=None,
by: IndexLabel | None = None,
left_by: Hashable | None = None,
right_by: Hashable | None = None,
suffixes: Suffixes = ("_x", "_y"),
tolerance=None,
tolerance: None | TimedeltaConvertibleTypes = None,
allow_exact_matches: bool = True,
direction: str = "backward",
) -> DataFrame:
Expand Down Expand Up @@ -622,7 +624,7 @@ def __init__(
sort: bool = True,
suffixes: Suffixes = ("_x", "_y"),
copy: bool = True,
indicator: bool = False,
indicator: bool | str = False,
validate: str | None = None,
):
_left = _validate_operand(left)
Expand Down
3 changes: 2 additions & 1 deletion pandas/io/excel/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def get_default_engine(ext, mode="reader"):
"xls": "xlwt",
"ods": "odf",
}
assert mode in ["reader", "writer"]
if mode not in ["reader", "writer"]:
raise ValueError('File mode must be either "reader" or "writer".')
if mode == "writer":
# Prefer xlsxwriter over openpyxl if installed
xlsxwriter = import_optional_dependency("xlsxwriter", errors="warn")
Expand Down
6 changes: 3 additions & 3 deletions pandas/tseries/frequencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def _maybe_add_count(base: str, count: float) -> str:
# Frequency comparison


def is_subperiod(source, target) -> bool:
def is_subperiod(source: str | DateOffset, target: str | DateOffset) -> bool:
"""
Returns True if downsampling is possible between source and target
frequencies
Expand Down Expand Up @@ -501,7 +501,7 @@ def is_subperiod(source, target) -> bool:
return False


def is_superperiod(source, target) -> bool:
def is_superperiod(source: str | DateOffset, target: str | DateOffset) -> bool:
"""
Returns True if upsampling is possible between source and target
frequencies
Expand Down Expand Up @@ -559,7 +559,7 @@ def is_superperiod(source, target) -> bool:
return False


def _maybe_coerce_freq(code) -> str:
def _maybe_coerce_freq(code: str | DateOffset) -> str:
"""we might need to coerce a code to a rule_code
and uppercase it

Expand Down
Loading