diff --git a/pandas-stubs/core/reshape/concat.pyi b/pandas-stubs/core/reshape/concat.pyi index f684a63f..54de1be3 100644 --- a/pandas-stubs/core/reshape/concat.pyi +++ b/pandas-stubs/core/reshape/concat.pyi @@ -12,8 +12,10 @@ from pandas import ( DataFrame, Series, ) +from typing_extensions import Never from pandas._typing import ( + Axis, AxisColumn, AxisIndex, HashableT1, @@ -24,9 +26,23 @@ from pandas._typing import ( @overload def concat( - objs: Iterable[DataFrame] | Mapping[HashableT1, DataFrame], + objs: Iterable[None] | Mapping[HashableT1, None], *, - axis: AxisIndex = ..., + axis: Axis = ..., + join: Literal["inner", "outer"] = ..., + ignore_index: bool = ..., + keys: Iterable[HashableT2] = ..., + levels: Sequence[list[HashableT3] | tuple[HashableT3, ...]] = ..., + names: list[HashableT4] = ..., + verify_integrity: bool = ..., + sort: bool = ..., + copy: bool = ..., +) -> Never: ... +@overload +def concat( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + objs: Iterable[DataFrame | None] | Mapping[HashableT1, DataFrame | None], + *, + axis: Axis = ..., join: Literal["inner", "outer"] = ..., ignore_index: bool = ..., keys: Iterable[HashableT2] = ..., @@ -38,7 +54,7 @@ def concat( ) -> DataFrame: ... @overload def concat( - objs: Iterable[Series] | Mapping[HashableT1, Series], + objs: Iterable[Series | None] | Mapping[HashableT1, Series | None], *, axis: AxisIndex = ..., join: Literal["inner", "outer"] = ..., @@ -52,7 +68,10 @@ def concat( ) -> Series: ... @overload def concat( - objs: Iterable[Series | DataFrame] | Mapping[HashableT1, Series | DataFrame], + objs: ( + Iterable[Series | DataFrame | None] + | Mapping[HashableT1, Series | DataFrame | None] + ), *, axis: AxisColumn, join: Literal["inner", "outer"] = ..., diff --git a/tests/test_pandas.py b/tests/test_pandas.py index ba154a8f..71551bbd 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -17,7 +17,10 @@ # TODO: github.com/pandas-dev/pandas/issues/55023 import pytest -from typing_extensions import assert_type +from typing_extensions import ( + Never, + assert_type, +) from pandas._libs.missing import NAType from pandas._libs.tslibs import NaTType @@ -49,6 +52,30 @@ def test_types_to_datetime() -> None: ) +def test_types_concat_none() -> None: + """Test concatenation with None values.""" + series = pd.Series([7, -5, 10]) + df = pd.DataFrame({"a": [7, -5, 10]}) + + check(assert_type(pd.concat([None, series]), pd.Series), pd.Series) + check(assert_type(pd.concat([None, df]), pd.DataFrame), pd.DataFrame) + check( + assert_type(pd.concat([None, series, df], axis=1), pd.DataFrame), pd.DataFrame + ) + + check(assert_type(pd.concat({"a": None, "b": series}), pd.Series), pd.Series) + check(assert_type(pd.concat({"a": None, "b": df}), pd.DataFrame), pd.DataFrame) + check( + assert_type(pd.concat({"a": None, "b": series, "c": df}, axis=1), pd.DataFrame), + pd.DataFrame, + ) + + if TYPE_CHECKING_INVALID_USAGE: + # using assert_type as otherwise the second call would not be type-checked + assert_type(pd.concat({"a": None}), Never) + assert_type(pd.concat([None]), Never) + + def test_types_concat() -> None: s: pd.Series = pd.Series([0, 1, -10]) s2: pd.Series = pd.Series([7, -5, 10])