Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: (Series | DataFrame).plot.bar #1152

Merged
merged 5 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions bigframes/operations/_matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
PLOT_TYPES = typing.Union[type[core.SamplingPlot], type[hist.HistPlot]]

PLOT_CLASSES: dict[str, PLOT_TYPES] = {
"hist": hist.HistPlot,
"line": core.LinePlot,
"area": core.AreaPlot,
"bar": core.BarPlot,
"line": core.LinePlot,
"scatter": core.ScatterPlot,
"hist": hist.HistPlot,
}


Expand Down
44 changes: 37 additions & 7 deletions bigframes/operations/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import abc
import typing
import warnings

import bigframes_vendored.constants as constants
import pandas as pd
Expand Down Expand Up @@ -46,10 +47,15 @@ def result(self):


class SamplingPlot(MPLPlot):
@abc.abstractproperty
@property
@abc.abstractmethod
def _kind(self):
pass

@property
def _sampling_warning_msg(self) -> typing.Optional[str]:
return None

def __init__(self, data, **kwargs) -> None:
self.kwargs = kwargs
self.data = data
Expand All @@ -61,6 +67,15 @@ def generate(self) -> None:
def _compute_sample_data(self, data):
# TODO: Cache the sampling data in the PlotAccessor.
sampling_n = self.kwargs.pop("sampling_n", DEFAULT_SAMPLING_N)
if self._sampling_warning_msg is not None:
total_n = data.shape[0]
if sampling_n < total_n:
warnings.warn(
self._sampling_warning_msg.format(
sampling_n=sampling_n, total_n=total_n
)
)

sampling_random_state = self.kwargs.pop(
"sampling_random_state", DEFAULT_SAMPLING_STATE
)
Expand All @@ -74,18 +89,33 @@ def _compute_plot_data(self):
return self._compute_sample_data(self.data)


class LinePlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["line"]:
return "line"


class AreaPlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["area"]:
return "area"


class BarPlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["bar"]:
return "bar"
Comment on lines +98 to +101
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does bar plot make sense as a sampling plot? It seems that bar plots are really meant for small-cardinalities and don't down-sample well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bar plots are discrete while the area/line/scatter plots are continuous. Hence, the downsampling in the bar is not straightforward compared to others. Thanks for suggestions here. I just added a warning message to suggest users to consider pre-processing data (aggregations or select top categories).


@property
def _sampling_warning_msg(self) -> typing.Optional[str]:
return (
"To optimize plotting performance, your data has been downsampled to {sampling_n} "
"rows from the original {total_n} rows. This may result in some data points "
"not being displayed. For a more comprehensive view, consider pre-processing "
"your data by aggregating it or selecting the top categories."
)


class LinePlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["line"]:
return "line"


class ScatterPlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["scatter"]:
Expand Down
10 changes: 9 additions & 1 deletion bigframes/operations/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class PlotAccessor(vendordt.PlotAccessor):
__doc__ = vendordt.PlotAccessor.__doc__

_common_kinds = ("line", "area", "hist")
_common_kinds = ("line", "area", "hist", "bar")
_dataframe_kinds = ("scatter",)
_all_kinds = _common_kinds + _dataframe_kinds

Expand Down Expand Up @@ -72,6 +72,14 @@ def area(
):
return self(kind="area", x=x, y=y, stacked=stacked, **kwargs)

def bar(
self,
x: typing.Optional[typing.Hashable] = None,
y: typing.Optional[typing.Hashable] = None,
**kwargs,
):
return self(kind="bar", x=x, y=y, **kwargs)

def scatter(
self,
x: typing.Optional[typing.Hashable] = None,
Expand Down
12 changes: 12 additions & 0 deletions tests/system/small/operations/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,18 @@ def test_area(scalars_dfs):
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])


def test_bar(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
col_names = ["int64_col", "float64_col", "int64_too"]
ax = scalars_df[col_names].plot.bar()
pd_ax = scalars_pandas_df[col_names].plot.bar()
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
for line, pd_line in zip(ax.lines, pd_ax.lines):
# Compare y coordinates between the lines
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])


def test_scatter(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
col_names = ["int64_col", "float64_col", "int64_too", "bool_col"]
Expand Down
60 changes: 60 additions & 0 deletions third_party/bigframes_vendored/pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,66 @@ def area(
"""
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

def bar(
self,
x: typing.Optional[typing.Hashable] = None,
y: typing.Optional[typing.Hashable] = None,
**kwargs,
):
"""
Draw a vertical bar plot.

This function calls `pandas.plot` to generate a plot with a random sample
of items. For consistent results, the random sampling is reproducible.
Use the `sampling_random_state` parameter to modify the sampling seed.

**Examples:**

Basic plot.

>>> import bigframes.pandas as bpd
>>> bpd.options.display.progress_bar = None
>>> df = bpd.DataFrame({'lab':['A', 'B', 'C'], 'val':[10, 30, 20]})
>>> ax = df.plot.bar(x='lab', y='val', rot=0)

Plot a whole dataframe to a bar plot. Each column is assigned a distinct color,
and each row is nested in a group along the horizontal axis.

>>> speed = [0.1, 17.5, 40, 48, 52, 69, 88]
>>> lifespan = [2, 8, 70, 1.5, 25, 12, 28]
>>> index = ['snail', 'pig', 'elephant',
... 'rabbit', 'giraffe', 'coyote', 'horse']
>>> df = bpd.DataFrame({'speed': speed, 'lifespan': lifespan}, index=index)
>>> ax = df.plot.bar(rot=0)

Plot stacked bar charts for the DataFrame.

>>> ax = df.plot.bar(stacked=True)

If you don’t like the default colours, you can specify how you’d like each column
to be colored.

>>> axes = df.plot.bar(
... rot=0, subplots=True, color={"speed": "red", "lifespan": "green"}
... )

Args:
x (label or position, optional):
Allows plotting of one column versus another. If not specified, the index
of the DataFrame is used.
y (label or position, optional):
Allows plotting of one column versus another. If not specified, all numerical
columns are used.
**kwargs:
Additional keyword arguments are documented in
:meth:`DataFrame.plot`.

Returns:
matplotlib.axes.Axes or numpy.ndarray:
Area plot, or array of area plots if subplots is True.
"""
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

def scatter(
self,
x: typing.Optional[typing.Hashable] = None,
Expand Down