REF: Clean up concat statefulness and validation #57933

Merged 4 commits on Mar 20, 2024
229 changes: 110 additions & 119 deletions pandas/core/reshape/concat.py
@@ -17,10 +17,7 @@

from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.common import (
is_bool,
is_iterator,
)
from pandas.core.dtypes.common import is_bool
from pandas.core.dtypes.concat import concat_compat
from pandas.core.dtypes.generic import (
ABCDataFrame,
@@ -423,11 +420,12 @@ def __init__(
self.ignore_index = ignore_index
self.verify_integrity = verify_integrity

objs, keys = self._clean_keys_and_objs(objs, keys)
objs, keys, ndims = _clean_keys_and_objs(objs, keys)

# figure out what our result ndim is going to be
ndims = self._get_ndims(objs)
sample, objs = self._get_sample_object(objs, ndims, keys, names, levels)
# select an object to be our result reference
sample, objs = _get_sample_object(
objs, ndims, keys, names, levels, self.intersect
)

# Standardize axis parameter to int
if sample.ndim == 1:
@@ -458,100 +456,6 @@ def __init__(
self.names = names or getattr(keys, "names", None)
self.levels = levels

def _get_ndims(self, objs: list[Series | DataFrame]) -> set[int]:
# figure out what our result ndim is going to be
ndims = set()
for obj in objs:
if not isinstance(obj, (ABCSeries, ABCDataFrame)):
msg = (
f"cannot concatenate object of type '{type(obj)}'; "
"only Series and DataFrame objs are valid"
)
raise TypeError(msg)

ndims.add(obj.ndim)
return ndims
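
(Review note, not part of the diff: the type check above moves into `_clean_keys_and_objs` below. A quick sketch of the user-facing behavior it enforces, with illustrative inputs:)

import pandas as pd

# Only Series and DataFrame inputs are accepted; anything else should
# raise TypeError before any reshaping work happens.
try:
    pd.concat([pd.Series([1, 2]), [3, 4]])  # a plain list is rejected
except TypeError as err:
    print(err)  # cannot concatenate object of type '<class 'list'>'; ...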

def _clean_keys_and_objs(
self,
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
keys,
) -> tuple[list[Series | DataFrame], Index | None]:
if isinstance(objs, abc.Mapping):
if keys is None:
keys = list(objs.keys())
objs_list = [objs[k] for k in keys]
else:
objs_list = list(objs)

if len(objs_list) == 0:
raise ValueError("No objects to concatenate")

if keys is None:
objs_list = list(com.not_none(*objs_list))
else:
# GH#1649
key_indices = []
clean_objs = []
if is_iterator(keys):
keys = list(keys)
if len(keys) != len(objs_list):
# GH#43485
raise ValueError(
f"The length of the keys ({len(keys)}) must match "
f"the length of the objects to concatenate ({len(objs_list)})"
)
for i, obj in enumerate(objs_list):
if obj is not None:
key_indices.append(i)
clean_objs.append(obj)
objs_list = clean_objs

if not isinstance(keys, Index):
keys = Index(keys)

if len(key_indices) < len(keys):
keys = keys.take(key_indices)

if len(objs_list) == 0:
raise ValueError("All objects passed were None")

return objs_list, keys

def _get_sample_object(
self,
objs: list[Series | DataFrame],
ndims: set[int],
keys,
names,
levels,
) -> tuple[Series | DataFrame, list[Series | DataFrame]]:
# get the sample
# want the highest ndim that we have, and must be non-empty
# unless all objs are empty
sample: Series | DataFrame | None = None
if len(ndims) > 1:
max_ndim = max(ndims)
for obj in objs:
if obj.ndim == max_ndim and np.sum(obj.shape):
sample = obj
break

else:
# filter out the empties if there are no multi-index possibilities
# note: keep empty Series, as they affect the result columns / name
non_empties = [obj for obj in objs if sum(obj.shape) > 0 or obj.ndim == 1]

if len(non_empties) and (
keys is None and names is None and levels is None and not self.intersect
):
objs = non_empties
sample = objs[0]

if sample is None:
sample = objs[0]
return sample, objs

def _sanitize_mixed_ndim(
self,
objs: list[Series | DataFrame],
@@ -664,29 +568,24 @@ def get_result(self):
out = sample._constructor_from_mgr(new_data, axes=new_data.axes)
return out.__finalize__(self, method="concat")

def _get_result_dim(self) -> int:
if self._is_series and self.bm_axis == 1:
return 2
else:
return self.objs[0].ndim

@cache_readonly
def new_axes(self) -> list[Index]:
ndim = self._get_result_dim()
if self._is_series and self.bm_axis == 1:
ndim = 2
else:
ndim = self.objs[0].ndim
return [
self._get_concat_axis if i == self.bm_axis else self._get_comb_axis(i)
self._get_concat_axis
if i == self.bm_axis
else get_objs_combined_axis(
self.objs,
axis=self.objs[0]._get_block_manager_axis(i),
intersect=self.intersect,
sort=self.sort,
)
for i in range(ndim)
]

def _get_comb_axis(self, i: AxisInt) -> Index:
data_axis = self.objs[0]._get_block_manager_axis(i)
return get_objs_combined_axis(
self.objs,
axis=data_axis,
intersect=self.intersect,
sort=self.sort,
)
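
(Review note, not part of the diff: the removed `_get_comb_axis` helper is inlined into `new_axes` above as a direct `get_objs_combined_axis` call. A small sketch of what the `intersect` flag controls, assuming default pandas semantics, where `join="inner"` is the public spelling of `intersect=True`:)

import pandas as pd

df1 = pd.DataFrame({"a": [1], "b": [2]})
df2 = pd.DataFrame({"b": [3], "c": [4]})

# join="outer" (intersect=False): union of the non-concat axis
print(pd.concat([df1, df2]).columns.tolist())  # expect ['a', 'b', 'c']

# join="inner" (intersect=True): intersection of the non-concat axis
print(pd.concat([df1, df2], join="inner").columns.tolist())  # expect ['b']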

@cache_readonly
def _get_concat_axis(self) -> Index:
"""
@@ -747,6 +646,98 @@ def _maybe_check_integrity(self, concat_index: Index) -> None:
raise ValueError(f"Indexes have overlapping values: {overlap}")


def _clean_keys_and_objs(
objs: Iterable[Series | DataFrame] | Mapping[HashableT, Series | DataFrame],
keys,
) -> tuple[list[Series | DataFrame], Index | None, set[int]]:
"""
Returns
-------
clean_objs : list[Series | DataFrame]
List of DataFrames and Series with Nones removed.
keys : Index | None
None if keys was None.
Index if objs was a Mapping or keys was not None; filtered where objs was None.
ndims : set[int]
Unique .ndim attributes of the objects encountered.
"""
if isinstance(objs, abc.Mapping):
if keys is None:
keys = objs.keys()
objs_list = [objs[k] for k in keys]
else:
objs_list = list(objs)

if len(objs_list) == 0:
raise ValueError("No objects to concatenate")

if keys is not None:
if not isinstance(keys, Index):
keys = Index(keys)
if len(keys) != len(objs_list):
# GH#43485
raise ValueError(
f"The length of the keys ({len(keys)}) must match "
f"the length of the objects to concatenate ({len(objs_list)})"
)

# GH#1649
key_indices = []
clean_objs = []
ndims = set()
for i, obj in enumerate(objs_list):
if obj is None:
continue
elif isinstance(obj, (ABCSeries, ABCDataFrame)):
key_indices.append(i)
clean_objs.append(obj)
ndims.add(obj.ndim)
else:
msg = (
f"cannot concatenate object of type '{type(obj)}'; "
"only Series and DataFrame objs are valid"
)
raise TypeError(msg)

if keys is not None and len(key_indices) < len(keys):
keys = keys.take(key_indices)

if len(clean_objs) == 0:
raise ValueError("All objects passed were None")

return clean_objs, keys, ndims
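
(Review note, not part of the diff: a usage sketch of the two validations consolidated here, with illustrative inputs:)

import pandas as pd

s1 = pd.Series([1, 2])
s2 = pd.Series([3, 4])

# Key/object length mismatch raises up front (GH#43485):
try:
    pd.concat([s1, s2], keys=["a"])
except ValueError as err:
    print(err)  # The length of the keys (1) must match ...

# None entries are dropped and their keys filtered with them (GH#1649):
out = pd.concat([s1, None, s2], keys=["a", "b", "c"])
print(out.index.get_level_values(0).unique().tolist())  # expect ['a', 'c']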


def _get_sample_object(
objs: list[Series | DataFrame],
ndims: set[int],
keys,
names,
levels,
intersect: bool,
) -> tuple[Series | DataFrame, list[Series | DataFrame]]:
# get the sample
# want the highest ndim that we have, and must be non-empty
# unless all objs are empty
if len(ndims) > 1:
max_ndim = max(ndims)
for obj in objs:
if obj.ndim == max_ndim and sum(obj.shape): # type: ignore[arg-type]
return obj, objs
elif keys is None and names is None and levels is None and not intersect:
# filter out the empties if there are no multi-index possibilities
# note: keep empty Series, as they affect the result columns / name
if ndims.pop() == 2:
non_empties = [obj for obj in objs if sum(obj.shape)]
else:
non_empties = objs

if len(non_empties):
return non_empties[0], non_empties

return objs[0], objs
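
(Review note, not part of the diff: a sketch of how the chosen sample drives the result's dimensionality when ndims are mixed:)

import pandas as pd

s = pd.Series([5, 6], name="y")
df = pd.DataFrame({"x": [1, 2]})

# The non-empty object with the highest ndim (the DataFrame) becomes
# the sample, so the concatenated result should be 2-dimensional.
out = pd.concat([s, df])
print(type(out).__name__)  # expect DataFrame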


def _concat_indexes(indexes) -> Index:
return indexes[0].append(indexes[1:])
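
(Review note, not part of the diff: `Index.append` accepts a list of Index objects, which is what this helper relies on:)

import pandas as pd

parts = [pd.Index([0, 1]), pd.Index([2]), pd.Index([3, 4])]
print(parts[0].append(parts[1:]).tolist())  # expect [0, 1, 2, 3, 4]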
