From a41b5450bde541d684d9415bb5fb2b96ef4f4237 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 9 Nov 2023 10:20:06 -0800 Subject: [PATCH] REF: Mess with data less outside of MPLPlot.__init__ (#55889) * REF: Mess with data less outside of MPLPlot.__init__ * lint fixup * pyright ignore * pyright ignore --- pandas/plotting/_matplotlib/core.py | 96 +++++++++++++++++------------ pandas/plotting/_matplotlib/hist.py | 7 ++- 2 files changed, 61 insertions(+), 42 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 05cf151992c96..d59220a1f97f8 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -89,11 +89,15 @@ from pandas._typing import ( IndexLabel, + NDFrameT, PlottingOrientation, npt, ) - from pandas import Series + from pandas import ( + PeriodIndex, + Series, + ) def _color_in_style(style: str) -> bool: @@ -161,8 +165,6 @@ def __init__( ) -> None: import matplotlib.pyplot as plt - self.data = data - # if users assign an empty list or tuple, raise `ValueError` # similar to current `df.box` and `df.hist` APIs. if by in ([], ()): @@ -193,9 +195,11 @@ def __init__( self.kind = kind - self.subplots = self._validate_subplots_kwarg(subplots) + self.subplots = type(self)._validate_subplots_kwarg( + subplots, data, kind=self._kind + ) - self.sharex = self._validate_sharex(sharex, ax, by) + self.sharex = type(self)._validate_sharex(sharex, ax, by) self.sharey = sharey self.figsize = figsize self.layout = layout @@ -245,10 +249,11 @@ def __init__( # parse errorbar input if given xerr = kwds.pop("xerr", None) yerr = kwds.pop("yerr", None) - self.errors = { - kw: self._parse_errorbars(kw, err) - for kw, err in zip(["xerr", "yerr"], [xerr, yerr]) - } + nseries = self._get_nseries(data) + xerr, data = type(self)._parse_errorbars("xerr", xerr, data, nseries) + yerr, data = type(self)._parse_errorbars("yerr", yerr, data, nseries) + self.errors = {"xerr": xerr, "yerr": yerr} + self.data = data if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, ABCIndex)): secondary_y = [secondary_y] @@ -271,7 +276,8 @@ def __init__( self._validate_color_args() @final - def _validate_sharex(self, sharex: bool | None, ax, by) -> bool: + @staticmethod + def _validate_sharex(sharex: bool | None, ax, by) -> bool: if sharex is None: # if by is defined, subplots are used and sharex should be False if ax is None and by is None: # pylint: disable=simplifiable-if-statement @@ -285,8 +291,9 @@ def _validate_sharex(self, sharex: bool | None, ax, by) -> bool: return bool(sharex) @final + @staticmethod def _validate_subplots_kwarg( - self, subplots: bool | Sequence[Sequence[str]] + subplots: bool | Sequence[Sequence[str]], data: Series | DataFrame, kind: str ) -> bool | list[tuple[int, ...]]: """ Validate the subplots parameter @@ -323,18 +330,18 @@ def _validate_subplots_kwarg( "area", "pie", ) - if self._kind not in supported_kinds: + if kind not in supported_kinds: raise ValueError( "When subplots is an iterable, kind must be " - f"one of {', '.join(supported_kinds)}. Got {self._kind}." + f"one of {', '.join(supported_kinds)}. Got {kind}." ) - if isinstance(self.data, ABCSeries): + if isinstance(data, ABCSeries): raise NotImplementedError( "An iterable subplots for a Series is not supported." ) - columns = self.data.columns + columns = data.columns if isinstance(columns, ABCMultiIndex): raise NotImplementedError( "An iterable subplots for a DataFrame with a MultiIndex column " @@ -442,18 +449,22 @@ def _iter_data( # typing. yield col, np.asarray(values.values) - @property - def nseries(self) -> int: + def _get_nseries(self, data: Series | DataFrame) -> int: # When `by` is explicitly assigned, grouped data size will be defined, and # this will determine number of subplots to have, aka `self.nseries` - if self.data.ndim == 1: + if data.ndim == 1: return 1 elif self.by is not None and self._kind == "hist": return len(self._grouped) elif self.by is not None and self._kind == "box": return len(self.columns) else: - return self.data.shape[1] + return data.shape[1] + + @final + @property + def nseries(self) -> int: + return self._get_nseries(self.data) @final def draw(self) -> None: @@ -880,10 +891,12 @@ def _get_xticks(self, convert_period: bool = False): index = self.data.index is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time") + x: list[int] | np.ndarray if self.use_index: if convert_period and isinstance(index, ABCPeriodIndex): self.data = self.data.reindex(index=index.sort_values()) - x = self.data.index.to_timestamp()._mpl_repr() + index = cast("PeriodIndex", self.data.index) + x = index.to_timestamp()._mpl_repr() elif is_any_real_numeric_dtype(index.dtype): # Matplotlib supports numeric values or datetime objects as # xaxis values. Taking LBYL approach here, by the time @@ -1050,8 +1063,12 @@ def _get_colors( color=self.kwds.get(color_kwds), ) + # TODO: tighter typing for first return? @final - def _parse_errorbars(self, label: str, err): + @staticmethod + def _parse_errorbars( + label: str, err, data: NDFrameT, nseries: int + ) -> tuple[Any, NDFrameT]: """ Look for error keyword arguments and return the actual errorbar data or return the error DataFrame/dict @@ -1071,7 +1088,7 @@ def _parse_errorbars(self, label: str, err): should be in a ``Mx2xN`` array. """ if err is None: - return None + return None, data def match_labels(data, e): e = e.reindex(data.index) @@ -1079,7 +1096,7 @@ def match_labels(data, e): # key-matched DataFrame if isinstance(err, ABCDataFrame): - err = match_labels(self.data, err) + err = match_labels(data, err) # key-matched dict elif isinstance(err, dict): pass @@ -1087,16 +1104,16 @@ def match_labels(data, e): # Series of error values elif isinstance(err, ABCSeries): # broadcast error series across data - err = match_labels(self.data, err) + err = match_labels(data, err) err = np.atleast_2d(err) - err = np.tile(err, (self.nseries, 1)) + err = np.tile(err, (nseries, 1)) # errors are a column in the dataframe elif isinstance(err, str): - evalues = self.data[err].values - self.data = self.data[self.data.columns.drop(err)] + evalues = data[err].values + data = data[data.columns.drop(err)] err = np.atleast_2d(evalues) - err = np.tile(err, (self.nseries, 1)) + err = np.tile(err, (nseries, 1)) elif is_list_like(err): if is_iterator(err): @@ -1108,40 +1125,40 @@ def match_labels(data, e): err_shape = err.shape # asymmetrical error bars - if isinstance(self.data, ABCSeries) and err_shape[0] == 2: + if isinstance(data, ABCSeries) and err_shape[0] == 2: err = np.expand_dims(err, 0) err_shape = err.shape - if err_shape[2] != len(self.data): + if err_shape[2] != len(data): raise ValueError( "Asymmetrical error bars should be provided " - f"with the shape (2, {len(self.data)})" + f"with the shape (2, {len(data)})" ) - elif isinstance(self.data, ABCDataFrame) and err.ndim == 3: + elif isinstance(data, ABCDataFrame) and err.ndim == 3: if ( - (err_shape[0] != self.nseries) + (err_shape[0] != nseries) or (err_shape[1] != 2) - or (err_shape[2] != len(self.data)) + or (err_shape[2] != len(data)) ): raise ValueError( "Asymmetrical error bars should be provided " - f"with the shape ({self.nseries}, 2, {len(self.data)})" + f"with the shape ({nseries}, 2, {len(data)})" ) # broadcast errors to each data series if len(err) == 1: - err = np.tile(err, (self.nseries, 1)) + err = np.tile(err, (nseries, 1)) elif is_number(err): err = np.tile( [err], # pyright: ignore[reportGeneralTypeIssues] - (self.nseries, len(self.data)), + (nseries, len(data)), ) else: msg = f"No valid {label} detected" raise ValueError(msg) - return err + return err, data # pyright: ignore[reportGeneralTypeIssues] @final def _get_errorbars( @@ -1215,8 +1232,7 @@ def __init__(self, data, x, y, **kwargs) -> None: self.y = y @final - @property - def nseries(self) -> int: + def _get_nseries(self, data: Series | DataFrame) -> int: return 1 @final diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index ed93dc740e1b4..e42914a9802dd 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -45,7 +45,10 @@ from pandas._typing import PlottingOrientation - from pandas import DataFrame + from pandas import ( + DataFrame, + Series, + ) class HistPlot(LinePlot): @@ -87,7 +90,7 @@ def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]): bins = self._calculate_bins(self.data, bins) return bins - def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray: + def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray: """Calculate bins given data""" nd_values = data.infer_objects(copy=False)._get_numeric_data() values = np.ravel(nd_values)