Skip to content

Commit

Permalink
feat: (Series | DataFrame).plot.bar (#1152)
Browse files Browse the repository at this point in the history
* feat: (Series | DataFrame).plot.bar

* add warning message

* fix mypy

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
chelsea-lin and gcf-owl-bot[bot] authored Nov 21, 2024
1 parent de923d0 commit 0fae2e0
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 10 deletions.
5 changes: 3 additions & 2 deletions bigframes/operations/_matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
# Plot implementation classes accepted as values of the registry below.
PLOT_TYPES = typing.Union[type[core.SamplingPlot], type[hist.HistPlot]]

# Registry mapping a plot ``kind`` string to the class implementing it.
# Each kind is listed exactly once: a repeated key in a dict literal is
# silently resolved in favour of the last entry, which can hide a wrong
# mapping. Contents and insertion order are the same as before.
PLOT_CLASSES: dict[str, PLOT_TYPES] = {
    "hist": hist.HistPlot,
    "line": core.LinePlot,
    "area": core.AreaPlot,
    "bar": core.BarPlot,
    "scatter": core.ScatterPlot,
}


Expand Down
44 changes: 37 additions & 7 deletions bigframes/operations/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import abc
import typing
import warnings

import bigframes_vendored.constants as constants
import pandas as pd
Expand Down Expand Up @@ -46,10 +47,15 @@ def result(self):


class SamplingPlot(MPLPlot):
@abc.abstractproperty
@property
@abc.abstractmethod
def _kind(self):
pass

@property
def _sampling_warning_msg(self) -> typing.Optional[str]:
return None

def __init__(self, data, **kwargs) -> None:
self.kwargs = kwargs
self.data = data
Expand All @@ -61,6 +67,15 @@ def generate(self) -> None:
def _compute_sample_data(self, data):
# TODO: Cache the sampling data in the PlotAccessor.
sampling_n = self.kwargs.pop("sampling_n", DEFAULT_SAMPLING_N)
if self._sampling_warning_msg is not None:
total_n = data.shape[0]
if sampling_n < total_n:
warnings.warn(
self._sampling_warning_msg.format(
sampling_n=sampling_n, total_n=total_n
)
)

sampling_random_state = self.kwargs.pop(
"sampling_random_state", DEFAULT_SAMPLING_STATE
)
Expand All @@ -74,18 +89,33 @@ def _compute_plot_data(self):
return self._compute_sample_data(self.data)


class LinePlot(SamplingPlot):
    """Sampling-based plot whose ``_kind`` is ``"line"``."""

    @property
    def _kind(self) -> typing.Literal["line"]:
        """Return the plot kind identifier for this class."""
        kind: typing.Literal["line"] = "line"
        return kind


class AreaPlot(SamplingPlot):
    """Sampling-based plot whose ``_kind`` is ``"area"``."""

    @property
    def _kind(self) -> typing.Literal["area"]:
        """Return the plot kind identifier for this class."""
        kind: typing.Literal["area"] = "area"
        return kind


class BarPlot(SamplingPlot):
    """Sampling-based plot whose ``_kind`` is ``"bar"``.

    Overrides ``_sampling_warning_msg`` with a message template that
    contains ``{sampling_n}`` and ``{total_n}`` placeholders.
    """

    @property
    def _kind(self) -> typing.Literal["bar"]:
        """Return the plot kind identifier for this class."""
        kind: typing.Literal["bar"] = "bar"
        return kind

    @property
    def _sampling_warning_msg(self) -> typing.Optional[str]:
        """Return the template for the downsampling warning message."""
        # Adjacent literals are concatenated at compile time; the trailing
        # space on each piece keeps the words separated.
        template = (
            "To optimize plotting performance, your data has been "
            "downsampled to {sampling_n} rows from the original {total_n} "
            "rows. This may result in some data points not being displayed. "
            "For a more comprehensive view, consider pre-processing your "
            "data by aggregating it or selecting the top categories."
        )
        return template


class LinePlot(SamplingPlot):
    """Sampling-based plot whose ``_kind`` is ``"line"``."""

    @property
    def _kind(self) -> typing.Literal["line"]:
        """Return the plot kind identifier for this class."""
        kind: typing.Literal["line"] = "line"
        return kind


class ScatterPlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["scatter"]:
Expand Down
10 changes: 9 additions & 1 deletion bigframes/operations/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class PlotAccessor(vendordt.PlotAccessor):
__doc__ = vendordt.PlotAccessor.__doc__

_common_kinds = ("line", "area", "hist")
_common_kinds = ("line", "area", "hist", "bar")
_dataframe_kinds = ("scatter",)
_all_kinds = _common_kinds + _dataframe_kinds

Expand Down Expand Up @@ -72,6 +72,14 @@ def area(
):
return self(kind="area", x=x, y=y, stacked=stacked, **kwargs)

def bar(
    self,
    x: typing.Optional[typing.Hashable] = None,
    y: typing.Optional[typing.Hashable] = None,
    **kwargs,
):
    # Deliberately no docstring: the user-facing documentation for ``bar``
    # is written on the vendored base-class method (presumably surfaced
    # through inheritance -- confirm before adding one here).
    #
    # Delegate to the accessor's call interface with the kind fixed to
    # "bar"; ``x``, ``y`` and any extra keyword arguments pass through
    # unchanged.
    plot = self(kind="bar", x=x, y=y, **kwargs)
    return plot

def scatter(
self,
x: typing.Optional[typing.Hashable] = None,
Expand Down
12 changes: 12 additions & 0 deletions tests/system/small/operations/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,18 @@ def test_area(scalars_dfs):
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])


def test_bar(scalars_dfs):
    """Check that ``plot.bar`` produces the same chart as pandas.

    Compares the tick positions on both axes and the height of every
    drawn bar between the two resulting axes.
    """
    scalars_df, scalars_pandas_df = scalars_dfs
    col_names = ["int64_col", "float64_col", "int64_too"]
    ax = scalars_df[col_names].plot.bar()
    pd_ax = scalars_pandas_df[col_names].plot.bar()
    tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
    tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
    # Bars are drawn as Rectangle patches rather than Line2D artists, so
    # iterating ``ax.lines`` never inspects the bar data. Compare patches,
    # and check the counts first because ``zip`` silently truncates.
    assert len(ax.patches) == len(pd_ax.patches)
    for rect, pd_rect in zip(ax.patches, pd_ax.patches):
        # Compare bar heights (the plotted y values).
        tm.assert_almost_equal(rect.get_height(), pd_rect.get_height())


def test_scatter(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
col_names = ["int64_col", "float64_col", "int64_too", "bool_col"]
Expand Down
60 changes: 60 additions & 0 deletions third_party/bigframes_vendored/pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,66 @@ def area(
"""
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

def bar(
    self,
    x: typing.Optional[typing.Hashable] = None,
    y: typing.Optional[typing.Hashable] = None,
    **kwargs,
):
    """
    Draw a vertical bar plot.

    This function calls `pandas.plot` to generate a plot with a random sample
    of items. For consistent results, the random sampling is reproducible.
    Use the `sampling_random_state` parameter to modify the sampling seed.

    **Examples:**

    Basic plot.

        >>> import bigframes.pandas as bpd
        >>> bpd.options.display.progress_bar = None
        >>> df = bpd.DataFrame({'lab':['A', 'B', 'C'], 'val':[10, 30, 20]})
        >>> ax = df.plot.bar(x='lab', y='val', rot=0)

    Plot a whole dataframe to a bar plot. Each column is assigned a distinct color,
    and each row is nested in a group along the horizontal axis.

        >>> speed = [0.1, 17.5, 40, 48, 52, 69, 88]
        >>> lifespan = [2, 8, 70, 1.5, 25, 12, 28]
        >>> index = ['snail', 'pig', 'elephant',
        ...          'rabbit', 'giraffe', 'coyote', 'horse']
        >>> df = bpd.DataFrame({'speed': speed, 'lifespan': lifespan}, index=index)
        >>> ax = df.plot.bar(rot=0)

    Plot stacked bar charts for the DataFrame.

        >>> ax = df.plot.bar(stacked=True)

    If you don’t like the default colours, you can specify how you’d like each column
    to be colored.

        >>> axes = df.plot.bar(
        ...     rot=0, subplots=True, color={"speed": "red", "lifespan": "green"}
        ... )

    Args:
        x (label or position, optional):
            Allows plotting of one column versus another. If not specified, the index
            of the DataFrame is used.
        y (label or position, optional):
            Allows plotting of one column versus another. If not specified, all numerical
            columns are used.
        **kwargs:
            Additional keyword arguments are documented in
            :meth:`DataFrame.plot`.

    Returns:
        matplotlib.axes.Axes or numpy.ndarray:
            Bar plot, or array of bar plots if subplots is True.
    """
    # This definition only carries the signature and documentation; it
    # always raises.
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

def scatter(
self,
x: typing.Optional[typing.Hashable] = None,
Expand Down

0 comments on commit 0fae2e0

Please sign in to comment.