Skip to content

Commit

Permalink
REF: Mess with data less outside of MPLPlot.__init__ (pandas-dev#55889)
Browse files (browse the repository at this point in the history)
* REF: Mess with data less outside of MPLPlot.__init__

* lint fixup

* pyright ignore

* pyright ignore
  • Loading branch information
jbrockmendel authored Nov 9, 2023
1 parent 5cedf87 commit a41b545
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 42 deletions.
96 changes: 56 additions & 40 deletions pandas/plotting/_matplotlib/core.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -89,11 +89,15 @@

from pandas._typing import (
IndexLabel,
NDFrameT,
PlottingOrientation,
npt,
)

from pandas import Series
from pandas import (
PeriodIndex,
Series,
)


def _color_in_style(style: str) -> bool:
Expand Down Expand Up @@ -161,8 +165,6 @@ def __init__(
) -> None:
import matplotlib.pyplot as plt

self.data = data

# if users assign an empty list or tuple, raise `ValueError`
# similar to current `df.box` and `df.hist` APIs.
if by in ([], ()):
Expand Down Expand Up @@ -193,9 +195,11 @@ def __init__(

self.kind = kind

self.subplots = self._validate_subplots_kwarg(subplots)
self.subplots = type(self)._validate_subplots_kwarg(
subplots, data, kind=self._kind
)

self.sharex = self._validate_sharex(sharex, ax, by)
self.sharex = type(self)._validate_sharex(sharex, ax, by)
self.sharey = sharey
self.figsize = figsize
self.layout = layout
Expand Down Expand Up @@ -245,10 +249,11 @@ def __init__(
# parse errorbar input if given
xerr = kwds.pop("xerr", None)
yerr = kwds.pop("yerr", None)
self.errors = {
kw: self._parse_errorbars(kw, err)
for kw, err in zip(["xerr", "yerr"], [xerr, yerr])
}
nseries = self._get_nseries(data)
xerr, data = type(self)._parse_errorbars("xerr", xerr, data, nseries)
yerr, data = type(self)._parse_errorbars("yerr", yerr, data, nseries)
self.errors = {"xerr": xerr, "yerr": yerr}
self.data = data

if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, ABCIndex)):
secondary_y = [secondary_y]
Expand All @@ -271,7 +276,8 @@ def __init__(
self._validate_color_args()

@final
def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
@staticmethod
def _validate_sharex(sharex: bool | None, ax, by) -> bool:
if sharex is None:
# if by is defined, subplots are used and sharex should be False
if ax is None and by is None: # pylint: disable=simplifiable-if-statement
Expand All @@ -285,8 +291,9 @@ def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
return bool(sharex)

@final
@staticmethod
def _validate_subplots_kwarg(
self, subplots: bool | Sequence[Sequence[str]]
subplots: bool | Sequence[Sequence[str]], data: Series | DataFrame, kind: str
) -> bool | list[tuple[int, ...]]:
"""
Validate the subplots parameter
Expand Down Expand Up @@ -323,18 +330,18 @@ def _validate_subplots_kwarg(
"area",
"pie",
)
if self._kind not in supported_kinds:
if kind not in supported_kinds:
raise ValueError(
"When subplots is an iterable, kind must be "
f"one of {', '.join(supported_kinds)}. Got {self._kind}."
f"one of {', '.join(supported_kinds)}. Got {kind}."
)

if isinstance(self.data, ABCSeries):
if isinstance(data, ABCSeries):
raise NotImplementedError(
"An iterable subplots for a Series is not supported."
)

columns = self.data.columns
columns = data.columns
if isinstance(columns, ABCMultiIndex):
raise NotImplementedError(
"An iterable subplots for a DataFrame with a MultiIndex column "
Expand Down Expand Up @@ -442,18 +449,22 @@ def _iter_data(
# typing.
yield col, np.asarray(values.values)

@property
def nseries(self) -> int:
def _get_nseries(self, data: Series | DataFrame) -> int:
# When `by` is explicitly assigned, grouped data size will be defined, and
# this will determine number of subplots to have, aka `self.nseries`
if self.data.ndim == 1:
if data.ndim == 1:
return 1
elif self.by is not None and self._kind == "hist":
return len(self._grouped)
elif self.by is not None and self._kind == "box":
return len(self.columns)
else:
return self.data.shape[1]
return data.shape[1]

@final
@property
def nseries(self) -> int:
return self._get_nseries(self.data)

@final
def draw(self) -> None:
Expand Down Expand Up @@ -880,10 +891,12 @@ def _get_xticks(self, convert_period: bool = False):
index = self.data.index
is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time")

x: list[int] | np.ndarray
if self.use_index:
if convert_period and isinstance(index, ABCPeriodIndex):
self.data = self.data.reindex(index=index.sort_values())
x = self.data.index.to_timestamp()._mpl_repr()
index = cast("PeriodIndex", self.data.index)
x = index.to_timestamp()._mpl_repr()
elif is_any_real_numeric_dtype(index.dtype):
# Matplotlib supports numeric values or datetime objects as
# xaxis values. Taking LBYL approach here, by the time
Expand Down Expand Up @@ -1050,8 +1063,12 @@ def _get_colors(
color=self.kwds.get(color_kwds),
)

# TODO: tighter typing for first return?
@final
def _parse_errorbars(self, label: str, err):
@staticmethod
def _parse_errorbars(
label: str, err, data: NDFrameT, nseries: int
) -> tuple[Any, NDFrameT]:
"""
Look for error keyword arguments and return the actual errorbar data
or return the error DataFrame/dict
Expand All @@ -1071,32 +1088,32 @@ def _parse_errorbars(self, label: str, err):
should be in a ``Mx2xN`` array.
"""
if err is None:
return None
return None, data

def match_labels(data, e):
e = e.reindex(data.index)
return e

# key-matched DataFrame
if isinstance(err, ABCDataFrame):
err = match_labels(self.data, err)
err = match_labels(data, err)
# key-matched dict
elif isinstance(err, dict):
pass

# Series of error values
elif isinstance(err, ABCSeries):
# broadcast error series across data
err = match_labels(self.data, err)
err = match_labels(data, err)
err = np.atleast_2d(err)
err = np.tile(err, (self.nseries, 1))
err = np.tile(err, (nseries, 1))

# errors are a column in the dataframe
elif isinstance(err, str):
evalues = self.data[err].values
self.data = self.data[self.data.columns.drop(err)]
evalues = data[err].values
data = data[data.columns.drop(err)]
err = np.atleast_2d(evalues)
err = np.tile(err, (self.nseries, 1))
err = np.tile(err, (nseries, 1))

elif is_list_like(err):
if is_iterator(err):
Expand All @@ -1108,40 +1125,40 @@ def match_labels(data, e):
err_shape = err.shape

# asymmetrical error bars
if isinstance(self.data, ABCSeries) and err_shape[0] == 2:
if isinstance(data, ABCSeries) and err_shape[0] == 2:
err = np.expand_dims(err, 0)
err_shape = err.shape
if err_shape[2] != len(self.data):
if err_shape[2] != len(data):
raise ValueError(
"Asymmetrical error bars should be provided "
f"with the shape (2, {len(self.data)})"
f"with the shape (2, {len(data)})"
)
elif isinstance(self.data, ABCDataFrame) and err.ndim == 3:
elif isinstance(data, ABCDataFrame) and err.ndim == 3:
if (
(err_shape[0] != self.nseries)
(err_shape[0] != nseries)
or (err_shape[1] != 2)
or (err_shape[2] != len(self.data))
or (err_shape[2] != len(data))
):
raise ValueError(
"Asymmetrical error bars should be provided "
f"with the shape ({self.nseries}, 2, {len(self.data)})"
f"with the shape ({nseries}, 2, {len(data)})"
)

# broadcast errors to each data series
if len(err) == 1:
err = np.tile(err, (self.nseries, 1))
err = np.tile(err, (nseries, 1))

elif is_number(err):
err = np.tile(
[err], # pyright: ignore[reportGeneralTypeIssues]
(self.nseries, len(self.data)),
(nseries, len(data)),
)

else:
msg = f"No valid {label} detected"
raise ValueError(msg)

return err
return err, data # pyright: ignore[reportGeneralTypeIssues]

@final
def _get_errorbars(
Expand Down Expand Up @@ -1215,8 +1232,7 @@ def __init__(self, data, x, y, **kwargs) -> None:
self.y = y

@final
@property
def nseries(self) -> int:
def _get_nseries(self, data: Series | DataFrame) -> int:
return 1

@final
Expand Down
7 changes: 5 additions & 2 deletions pandas/plotting/_matplotlib/hist.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -45,7 +45,10 @@

from pandas._typing import PlottingOrientation

from pandas import DataFrame
from pandas import (
DataFrame,
Series,
)


class HistPlot(LinePlot):
Expand Down Expand Up @@ -87,7 +90,7 @@ def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
bins = self._calculate_bins(self.data, bins)
return bins

def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
"""Calculate bins given data"""
nd_values = data.infer_objects(copy=False)._get_numeric_data()
values = np.ravel(nd_values)
Expand Down

0 comments on commit a41b545

Please sign in to comment.