Skip to content

Commit

Permalink
Fix calling df.agg with scalar dict values (#861)
Browse files Browse the repository at this point in the history
Fixes #846
  • Loading branch information
hamdanal authored Feb 12, 2024
1 parent 176805d commit 6aae41c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
9 changes: 6 additions & 3 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ from pandas._typing import (
S1,
AggFuncTypeBase,
AggFuncTypeDictFrame,
AggFuncTypeDictSeries,
AggFuncTypeFrame,
AnyArrayLike,
ArrayLike,
Expand Down Expand Up @@ -1140,7 +1141,9 @@ class DataFrame(NDFrame, OpsMixin):
) -> DataFrame: ...
def diff(self, periods: int = ..., axis: Axis = ...) -> DataFrame: ...
@overload
def agg(self, func: AggFuncTypeBase, axis: Axis = ..., **kwargs) -> Series: ...
def agg( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
self, func: AggFuncTypeBase | AggFuncTypeDictSeries, axis: Axis = ..., **kwargs
) -> Series: ...
@overload
def agg(
self,
Expand All @@ -1149,8 +1152,8 @@ class DataFrame(NDFrame, OpsMixin):
**kwargs,
) -> DataFrame: ...
@overload
def aggregate(
self, func: AggFuncTypeBase, axis: Axis = ..., **kwargs
def aggregate( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
self, func: AggFuncTypeBase | AggFuncTypeDictSeries, axis: Axis = ..., **kwargs
) -> Series: ...
@overload
def aggregate(
Expand Down
7 changes: 7 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,10 @@ def test_types_agg() -> None:
assert_type(df.agg({"A": ["min", "max"], "B": "min"}), pd.DataFrame),
pd.DataFrame,
)
check(assert_type(df.agg({"A": ["mean"]}), pd.DataFrame), pd.DataFrame)
check(assert_type(df.agg("mean", axis=1), pd.Series), pd.Series)
check(assert_type(df.agg({"A": "mean"}), pd.Series), pd.Series)
check(assert_type(df.agg({"A": "mean", "B": "sum"}), pd.Series), pd.Series)


def test_types_aggregate() -> None:
Expand All @@ -1299,6 +1302,10 @@ def test_types_aggregate() -> None:
assert_type(df.aggregate({"A": ["min", "max"], "B": "min"}), pd.DataFrame),
pd.DataFrame,
)
check(assert_type(df.aggregate({"A": ["mean"]}), pd.DataFrame), pd.DataFrame)
check(assert_type(df.aggregate("mean", axis=1), pd.Series), pd.Series)
check(assert_type(df.aggregate({"A": "mean"}), pd.Series), pd.Series)
check(assert_type(df.aggregate({"A": "mean", "B": "sum"}), pd.Series), pd.Series)


def test_types_transform() -> None:
Expand Down

0 comments on commit 6aae41c

Please sign in to comment.