diff --git a/bigframes/operations/_matplotlib/__init__.py b/bigframes/operations/_matplotlib/__init__.py index 6ffe71139d..5f99d3b50a 100644 --- a/bigframes/operations/_matplotlib/__init__.py +++ b/bigframes/operations/_matplotlib/__init__.py @@ -20,10 +20,11 @@ PLOT_TYPES = typing.Union[type[core.SamplingPlot], type[hist.HistPlot]] PLOT_CLASSES: dict[str, PLOT_TYPES] = { - "hist": hist.HistPlot, - "line": core.LinePlot, "area": core.AreaPlot, + "bar": core.BarPlot, + "line": core.LinePlot, "scatter": core.ScatterPlot, + "hist": hist.HistPlot, } diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 9e59e09877..9292814fc0 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -46,7 +46,8 @@ def result(self): class SamplingPlot(MPLPlot): - @abc.abstractproperty + @property + @abc.abstractmethod def _kind(self): pass @@ -74,18 +75,24 @@ def _compute_plot_data(self): return self._compute_sample_data(self.data) -class LinePlot(SamplingPlot): - @property - def _kind(self) -> typing.Literal["line"]: - return "line" - - class AreaPlot(SamplingPlot): @property def _kind(self) -> typing.Literal["area"]: return "area" +class BarPlot(SamplingPlot): + @property + def _kind(self) -> typing.Literal["bar"]: + return "bar" + + +class LinePlot(SamplingPlot): + @property + def _kind(self) -> typing.Literal["line"]: + return "line" + + class ScatterPlot(SamplingPlot): @property def _kind(self) -> typing.Literal["scatter"]: diff --git a/bigframes/operations/plotting.py b/bigframes/operations/plotting.py index a45b825354..e9a86be6c9 100644 --- a/bigframes/operations/plotting.py +++ b/bigframes/operations/plotting.py @@ -23,7 +23,7 @@ class PlotAccessor(vendordt.PlotAccessor): __doc__ = vendordt.PlotAccessor.__doc__ - _common_kinds = ("line", "area", "hist") + _common_kinds = ("line", "area", "hist", "bar") _dataframe_kinds = ("scatter",) _all_kinds = _common_kinds + _dataframe_kinds @@ -72,6 +72,14 @@ def area( ): return self(kind="area", x=x, y=y, stacked=stacked, **kwargs) + def bar( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + **kwargs, + ): + return self(kind="bar", x=x, y=y, **kwargs) + def scatter( self, x: typing.Optional[typing.Hashable] = None, diff --git a/tests/system/small/operations/test_plotting.py b/tests/system/small/operations/test_plotting.py index 7be44e0a0f..3624232ea0 100644 --- a/tests/system/small/operations/test_plotting.py +++ b/tests/system/small/operations/test_plotting.py @@ -195,6 +195,18 @@ def test_area(scalars_dfs): tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1]) +def test_bar(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_names = ["int64_col", "float64_col", "int64_too"] + ax = scalars_df[col_names].plot.bar() + pd_ax = scalars_pandas_df[col_names].plot.bar() + tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks()) + tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks()) + for line, pd_line in zip(ax.lines, pd_ax.lines): + # Compare y coordinates between the lines + tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1]) + + def test_scatter(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs col_names = ["int64_col", "float64_col", "int64_too", "bool_col"] diff --git a/third_party/bigframes_vendored/pandas/plotting/_core.py b/third_party/bigframes_vendored/pandas/plotting/_core.py index 2409068fa8..4ed5c8eb0b 100644 --- a/third_party/bigframes_vendored/pandas/plotting/_core.py +++ b/third_party/bigframes_vendored/pandas/plotting/_core.py @@ -215,6 +215,66 @@ def area( """ raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def bar( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + **kwargs, + ): + """ + Draw a vertical bar plot. + + This function calls `pandas.plot` to generate a plot with a random sample + of items. For consistent results, the random sampling is reproducible. + Use the `sampling_random_state` parameter to modify the sampling seed. + + **Examples:** + + Basic plot. + + >>> import bigframes.pandas as bpd + >>> bpd.options.display.progress_bar = None + >>> df = bpd.DataFrame({'lab':['A', 'B', 'C'], 'val':[10, 30, 20]}) + >>> ax = df.plot.bar(x='lab', y='val', rot=0) + + Plot a whole dataframe to a bar plot. Each column is assigned a distinct color, + and each row is nested in a group along the horizontal axis. + + >>> speed = [0.1, 17.5, 40, 48, 52, 69, 88] + >>> lifespan = [2, 8, 70, 1.5, 25, 12, 28] + >>> index = ['snail', 'pig', 'elephant', + ... 'rabbit', 'giraffe', 'coyote', 'horse'] + >>> df = bpd.DataFrame({'speed': speed, 'lifespan': lifespan}, index=index) + >>> ax = df.plot.bar(rot=0) + + Plot stacked bar charts for the DataFrame. + + >>> ax = df.plot.bar(stacked=True) + + If you don’t like the default colours, you can specify how you’d like each column + to be colored. + + >>> axes = df.plot.bar( + ... rot=0, subplots=True, color={"speed": "red", "lifespan": "green"} + ... ) + + Args: + x (label or position, optional): + Allows plotting of one column versus another. If not specified, the index + of the DataFrame is used. + y (label or position, optional): + Allows plotting of one column versus another. If not specified, all numerical + columns are used. + **kwargs: + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns: + matplotlib.axes.Axes or numpy.ndarray: + Area plot, or array of area plots if subplots is True. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def scatter( self, x: typing.Optional[typing.Hashable] = None,