From 0fae2e0291ec8d22341b5b543e8f1b384f83cd3c Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 21 Nov 2024 13:40:40 -0800 Subject: [PATCH] feat: (Series | DataFrame).plot.bar (#1152) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: (Series | DataFrame).plot.bar * add warning message * fix mypy * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --------- Co-authored-by: Owl Bot --- bigframes/operations/_matplotlib/__init__.py | 5 +- bigframes/operations/_matplotlib/core.py | 44 +++++++++++--- bigframes/operations/plotting.py | 10 +++- .../system/small/operations/test_plotting.py | 12 ++++ .../pandas/plotting/_core.py | 60 +++++++++++++++++++ 5 files changed, 121 insertions(+), 10 deletions(-) diff --git a/bigframes/operations/_matplotlib/__init__.py b/bigframes/operations/_matplotlib/__init__.py index 6ffe71139d..5f99d3b50a 100644 --- a/bigframes/operations/_matplotlib/__init__.py +++ b/bigframes/operations/_matplotlib/__init__.py @@ -20,10 +20,11 @@ PLOT_TYPES = typing.Union[type[core.SamplingPlot], type[hist.HistPlot]] PLOT_CLASSES: dict[str, PLOT_TYPES] = { - "hist": hist.HistPlot, - "line": core.LinePlot, "area": core.AreaPlot, + "bar": core.BarPlot, + "line": core.LinePlot, "scatter": core.ScatterPlot, + "hist": hist.HistPlot, } diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 9e59e09877..b7c926be99 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -14,6 +14,7 @@ import abc import typing +import warnings import bigframes_vendored.constants as constants import pandas as pd @@ -46,10 +47,15 @@ def result(self): class SamplingPlot(MPLPlot): - @abc.abstractproperty + @property + @abc.abstractmethod def _kind(self): pass + @property + def _sampling_warning_msg(self) -> typing.Optional[str]: + return None + def __init__(self, 
data, **kwargs) -> None: self.kwargs = kwargs self.data = data @@ -61,6 +67,15 @@ def generate(self) -> None: def _compute_sample_data(self, data): # TODO: Cache the sampling data in the PlotAccessor. sampling_n = self.kwargs.pop("sampling_n", DEFAULT_SAMPLING_N) + if self._sampling_warning_msg is not None: + total_n = data.shape[0] + if sampling_n < total_n: + warnings.warn( + self._sampling_warning_msg.format( + sampling_n=sampling_n, total_n=total_n + ) + ) + sampling_random_state = self.kwargs.pop( "sampling_random_state", DEFAULT_SAMPLING_STATE ) @@ -74,18 +89,33 @@ def _compute_plot_data(self): return self._compute_sample_data(self.data) -class LinePlot(SamplingPlot): - @property - def _kind(self) -> typing.Literal["line"]: - return "line" - - class AreaPlot(SamplingPlot): @property def _kind(self) -> typing.Literal["area"]: return "area" +class BarPlot(SamplingPlot): + @property + def _kind(self) -> typing.Literal["bar"]: + return "bar" + + @property + def _sampling_warning_msg(self) -> typing.Optional[str]: + return ( + "To optimize plotting performance, your data has been downsampled to {sampling_n} " + "rows from the original {total_n} rows. This may result in some data points " + "not being displayed. For a more comprehensive view, consider pre-processing " + "your data by aggregating it or selecting the top categories." 
+ ) + + +class LinePlot(SamplingPlot): + @property + def _kind(self) -> typing.Literal["line"]: + return "line" + + class ScatterPlot(SamplingPlot): @property def _kind(self) -> typing.Literal["scatter"]: diff --git a/bigframes/operations/plotting.py b/bigframes/operations/plotting.py index a45b825354..e9a86be6c9 100644 --- a/bigframes/operations/plotting.py +++ b/bigframes/operations/plotting.py @@ -23,7 +23,7 @@ class PlotAccessor(vendordt.PlotAccessor): __doc__ = vendordt.PlotAccessor.__doc__ - _common_kinds = ("line", "area", "hist") + _common_kinds = ("line", "area", "hist", "bar") _dataframe_kinds = ("scatter",) _all_kinds = _common_kinds + _dataframe_kinds @@ -72,6 +72,14 @@ def area( ): return self(kind="area", x=x, y=y, stacked=stacked, **kwargs) + def bar( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + **kwargs, + ): + return self(kind="bar", x=x, y=y, **kwargs) + def scatter( self, x: typing.Optional[typing.Hashable] = None, diff --git a/tests/system/small/operations/test_plotting.py b/tests/system/small/operations/test_plotting.py index 7be44e0a0f..3624232ea0 100644 --- a/tests/system/small/operations/test_plotting.py +++ b/tests/system/small/operations/test_plotting.py @@ -195,6 +195,18 @@ def test_area(scalars_dfs): tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1]) +def test_bar(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_names = ["int64_col", "float64_col", "int64_too"] + ax = scalars_df[col_names].plot.bar() + pd_ax = scalars_pandas_df[col_names].plot.bar() + tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks()) + tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks()) + for line, pd_line in zip(ax.lines, pd_ax.lines): + # Compare y coordinates between the lines + tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1]) + + def test_scatter(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs col_names = ["int64_col", "float64_col", 
"int64_too", "bool_col"] diff --git a/third_party/bigframes_vendored/pandas/plotting/_core.py b/third_party/bigframes_vendored/pandas/plotting/_core.py index 2409068fa8..4ed5c8eb0b 100644 --- a/third_party/bigframes_vendored/pandas/plotting/_core.py +++ b/third_party/bigframes_vendored/pandas/plotting/_core.py @@ -215,6 +215,66 @@ def area( """ raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def bar( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + **kwargs, + ): + """ + Draw a vertical bar plot. + + This function calls `pandas.plot` to generate a plot with a random sample + of items. For consistent results, the random sampling is reproducible. + Use the `sampling_random_state` parameter to modify the sampling seed. + + **Examples:** + + Basic plot. + + >>> import bigframes.pandas as bpd + >>> bpd.options.display.progress_bar = None + >>> df = bpd.DataFrame({'lab':['A', 'B', 'C'], 'val':[10, 30, 20]}) + >>> ax = df.plot.bar(x='lab', y='val', rot=0) + + Plot a whole dataframe to a bar plot. Each column is assigned a distinct color, + and each row is nested in a group along the horizontal axis. + + >>> speed = [0.1, 17.5, 40, 48, 52, 69, 88] + >>> lifespan = [2, 8, 70, 1.5, 25, 12, 28] + >>> index = ['snail', 'pig', 'elephant', + ... 'rabbit', 'giraffe', 'coyote', 'horse'] + >>> df = bpd.DataFrame({'speed': speed, 'lifespan': lifespan}, index=index) + >>> ax = df.plot.bar(rot=0) + + Plot stacked bar charts for the DataFrame. + + >>> ax = df.plot.bar(stacked=True) + + If you don’t like the default colours, you can specify how you’d like each column + to be colored. + + >>> axes = df.plot.bar( + ... rot=0, subplots=True, color={"speed": "red", "lifespan": "green"} + ... ) + + Args: + x (label or position, optional): + Allows plotting of one column versus another. If not specified, the index + of the DataFrame is used. 
+ y (label or position, optional): + Allows plotting of one column versus another. If not specified, all numerical + columns are used. + **kwargs: + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns: + matplotlib.axes.Axes or numpy.ndarray: + Bar plot, or array of bar plots if subplots is True. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def scatter( self, x: typing.Optional[typing.Hashable] = None,