Skip to content

Commit

Permalink
Add unit property and as_unit method to TimestampSeries and Timedelta…
Browse files Browse the repository at this point in the history
…Series
  • Loading branch information
skatsuta committed Feb 12, 2024
1 parent dc84f57 commit cd35a2f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
36 changes: 22 additions & 14 deletions pandas-stubs/core/indexes/accessors.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -154,40 +154,40 @@ class _DatetimeLikeOps(
# type of the series, we don't know which kind of series was ...ed
# in to the dt accessor

_DTRoundingMethodReturnType = TypeVar(
"_DTRoundingMethodReturnType",
_DTTimestampTimedeltaReturnType = TypeVar(
"_DTTimestampTimedeltaReturnType",
Series,
TimedeltaSeries,
TimestampSeries,
TimedeltaSeries,
DatetimeIndex,
TimedeltaIndex,
)

class _DatetimeRoundingMethods(Generic[_DTRoundingMethodReturnType]):
class _DatetimeRoundingMethods(Generic[_DTTimestampTimedeltaReturnType]):
def round(
self,
freq: str | BaseOffset | None,
ambiguous: Literal["raise", "infer", "NaT"] | np_ndarray_bool = ...,
nonexistent: Literal["shift_forward", "shift_backward", "NaT", "raise"]
| timedelta
| Timedelta = ...,
) -> _DTRoundingMethodReturnType: ...
) -> _DTTimestampTimedeltaReturnType: ...
def floor(
self,
freq: str | BaseOffset | None,
ambiguous: Literal["raise", "infer", "NaT"] | np_ndarray_bool = ...,
nonexistent: Literal["shift_forward", "shift_backward", "NaT", "raise"]
| timedelta
| Timedelta = ...,
) -> _DTRoundingMethodReturnType: ...
) -> _DTTimestampTimedeltaReturnType: ...
def ceil(
self,
freq: str | BaseOffset | None,
ambiguous: Literal["raise", "infer", "NaT"] | np_ndarray_bool = ...,
nonexistent: Literal["shift_forward", "shift_backward", "NaT", "raise"]
| timedelta
| Timedelta = ...,
) -> _DTRoundingMethodReturnType: ...
) -> _DTTimestampTimedeltaReturnType: ...

_DTNormalizeReturnType = TypeVar(
"_DTNormalizeReturnType", TimestampSeries, DatetimeIndex
Expand All @@ -196,9 +196,9 @@ _DTStrKindReturnType = TypeVar("_DTStrKindReturnType", Series[str], Index)
_DTToPeriodReturnType = TypeVar("_DTToPeriodReturnType", PeriodSeries, PeriodIndex)

class _DatetimeLikeNoTZMethods(
_DatetimeRoundingMethods[_DTRoundingMethodReturnType],
_DatetimeRoundingMethods[_DTTimestampTimedeltaReturnType],
Generic[
_DTRoundingMethodReturnType,
_DTTimestampTimedeltaReturnType,
_DTNormalizeReturnType,
_DTStrKindReturnType,
_DTToPeriodReturnType,
Expand Down Expand Up @@ -230,15 +230,15 @@ class _DatetimeNoTZProperties(
_DTFreqReturnType,
],
_DatetimeLikeNoTZMethods[
_DTRoundingMethodReturnType,
_DTTimestampTimedeltaReturnType,
_DTNormalizeReturnType,
_DTStrKindReturnType,
_DTToPeriodReturnType,
],
Generic[
_DTFieldOpsReturnType,
_DTBoolOpsReturnType,
_DTRoundingMethodReturnType,
_DTTimestampTimedeltaReturnType,
_DTOtherOpsDateReturnType,
_DTOtherOpsTimeReturnType,
_DTFreqReturnType,
Expand All @@ -253,7 +253,7 @@ class DatetimeProperties(
_DatetimeNoTZProperties[
_DTFieldOpsReturnType,
_DTBoolOpsReturnType,
_DTRoundingMethodReturnType,
_DTTimestampTimedeltaReturnType,
_DTOtherOpsDateReturnType,
_DTOtherOpsTimeReturnType,
_DTFreqReturnType,
Expand All @@ -264,7 +264,7 @@ class DatetimeProperties(
Generic[
_DTFieldOpsReturnType,
_DTBoolOpsReturnType,
_DTRoundingMethodReturnType,
_DTTimestampTimedeltaReturnType,
_DTOtherOpsDateReturnType,
_DTOtherOpsTimeReturnType,
_DTFreqReturnType,
Expand All @@ -275,6 +275,11 @@ class DatetimeProperties(
):
def to_pydatetime(self) -> np.ndarray: ...
def isocalendar(self) -> DataFrame: ...
@property
def unit(self) -> str: ...
def as_unit(
self, unit: Literal["s", "ms", "us", "ns"]
) -> _DTTimestampTimedeltaReturnType: ...

_TDNoRoundingMethodReturnType = TypeVar(
"_TDNoRoundingMethodReturnType", Series[int], Index
Expand All @@ -301,7 +306,10 @@ class TimedeltaProperties(
Properties,
_TimedeltaPropertiesNoRounding[Series[int], Series[float]],
_DatetimeRoundingMethods[TimedeltaSeries],
): ...
):
@property
def unit(self) -> str: ...
def as_unit(self, unit: Literal["s", "ms", "us", "ns"]) -> TimedeltaSeries: ...

_PeriodDTReturnTypes = TypeVar("_PeriodDTReturnTypes", TimestampSeries, DatetimeIndex)
_PeriodIntReturnTypes = TypeVar("_PeriodIntReturnTypes", Series[int], Index[int])
Expand Down
35 changes: 35 additions & 0 deletions tests/test_timefuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TYPE_CHECKING,
Any,
Optional,
cast,
)

import numpy as np
Expand Down Expand Up @@ -428,6 +429,11 @@ def test_series_dt_accessors() -> None:
)
check(assert_type(s0.dt.month_name(), "pd.Series[str]"), pd.Series, str)
check(assert_type(s0.dt.day_name(), "pd.Series[str]"), pd.Series, str)
check(assert_type(s0.dt.unit, str), str)
check(assert_type(s0.dt.as_unit("s"), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("ms"), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("us"), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("ns"), "TimestampSeries"), pd.Series, pd.Timestamp)

i1 = pd.period_range(start="2022-06-01", periods=10)

Expand Down Expand Up @@ -455,6 +461,35 @@ def test_series_dt_accessors() -> None:
check(assert_type(s2.dt.components, pd.DataFrame), pd.DataFrame)
check(assert_type(s2.dt.to_pytimedelta(), np.ndarray), np.ndarray)
check(assert_type(s2.dt.total_seconds(), "pd.Series[float]"), pd.Series, float)
check(assert_type(s2.dt.unit, str), str)
check(assert_type(s2.dt.as_unit("s"), "TimedeltaSeries"), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("ms"), "TimedeltaSeries"), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("us"), "TimedeltaSeries"), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("ns"), "TimedeltaSeries"), pd.Series, pd.Timedelta)

# Checks for general Series other than TimestampSeries and TimedeltaSeries

s4 = cast(
"pd.Series[pd.Timestamp]",
pd.Series([pd.Timestamp("2024-01-01"), pd.Timestamp("2024-01-02")]),
)

check(assert_type(s4.dt.unit, str), str)
check(assert_type(s4.dt.as_unit("s"), pd.Series), pd.Series, pd.Timestamp)
check(assert_type(s4.dt.as_unit("ms"), pd.Series), pd.Series, pd.Timestamp)
check(assert_type(s4.dt.as_unit("us"), pd.Series), pd.Series, pd.Timestamp)
check(assert_type(s4.dt.as_unit("ns"), pd.Series), pd.Series, pd.Timestamp)

s5 = cast(
"pd.Series[pd.Timedelta]",
pd.Series([pd.Timedelta("1 day"), pd.Timedelta("2 days")]),
)

check(assert_type(s5.dt.unit, str), str)
check(assert_type(s5.dt.as_unit("s"), pd.Series), pd.Series, pd.Timedelta)
check(assert_type(s5.dt.as_unit("ms"), pd.Series), pd.Series, pd.Timedelta)
check(assert_type(s5.dt.as_unit("us"), pd.Series), pd.Series, pd.Timedelta)
check(assert_type(s5.dt.as_unit("ns"), pd.Series), pd.Series, pd.Timedelta)


def test_datetimeindex_accessors() -> None:
Expand Down

0 comments on commit cd35a2f

Please sign in to comment.