Skip to content

Commit

Permalink
feat: (Series|DataFrame).plot (#438)
Browse files Browse the repository at this point in the history
  • Loading branch information
chelsea-lin authored Mar 15, 2024
1 parent 91bd39e commit 1c3e668
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 30 deletions.
57 changes: 28 additions & 29 deletions bigframes/operations/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,45 @@
class PlotAccessor(vendordt.PlotAccessor):
__doc__ = vendordt.PlotAccessor.__doc__

_common_kinds = ("line", "area", "hist")
_dataframe_kinds = ("scatter",)
_all_kinds = _common_kinds + _dataframe_kinds

def __call__(self, **kwargs):
    """Validate a plot request and hand it to ``bfplt.plot``.

    Args:
        **kwargs:
            Plot options. ``backend`` and ``kind`` are consumed here;
            everything else is forwarded unchanged to ``bfplt.plot``.
            ``backend`` may be omitted, ``None`` or ``"matplotlib"``.
            ``kind`` defaults to ``"line"``.

    Returns:
        The result of ``bfplt.plot`` for a copy of the parent object.

    Raises:
        NotImplementedError: If a backend other than matplotlib is
            requested, or if ``kind`` is not listed in ``_all_kinds``.
        ValueError: If a kind from ``_dataframe_kinds`` is requested
            while the parent object is a Series.
    """
    # Function-scope import; presumably avoids a circular import between
    # the plotting and series modules -- TODO confirm.
    import bigframes.series as series

    # Naming the one supported backend explicitly is not an error; only a
    # different backend is rejected.
    if kwargs.pop("backend", None) not in (None, "matplotlib"):
        raise NotImplementedError(
            f"Only support matplotlib backend for now. {constants.FEEDBACK_LINK}"
        )

    kind = kwargs.pop("kind", "line")
    if kind not in self._all_kinds:
        raise NotImplementedError(
            f"{kind} is not a valid plot kind supported for now. {constants.FEEDBACK_LINK}"
        )

    # Check the parent's type before copying so a rejected request does not
    # pay for a copy it never uses.
    if kind in self._dataframe_kinds and isinstance(self._parent, series.Series):
        raise ValueError(f"plot kind {kind} can only be used for data frames")

    return bfplt.plot(self._parent.copy(), kind=kind, **kwargs)

def __init__(self, data) -> None:
    """Bind this accessor to the object whose data it plots.

    Args:
        data: The object to plot. It is stored as ``_parent``; the plotting
            methods work on a copy of it.
    """
    self._parent = data

def hist(
    self, by: typing.Optional[typing.Sequence[str]] = None, bins: int = 10, **kwargs
):
    """Request a ``"hist"`` plot through the accessor's ``__call__``.

    Delegating keeps the backend and kind validation in one place instead
    of repeating it in every per-kind method.

    Args:
        by: Forwarded unchanged as the ``by`` plot option.
        bins: Forwarded unchanged as the ``bins`` plot option (default 10).
        **kwargs: Any further plot options, forwarded unchanged.

    Returns:
        Whatever calling the accessor with ``kind="hist"`` returns.
    """
    return self(kind="hist", by=by, bins=bins, **kwargs)

def line(
    self,
    x: typing.Optional[typing.Hashable] = None,
    y: typing.Optional[typing.Hashable] = None,
    **kwargs,
):
    """Request a ``"line"`` plot through the accessor's ``__call__``.

    Delegating means the request gets the same backend and kind validation
    as a direct ``.plot(kind="line")`` call.

    Args:
        x: Forwarded unchanged as the ``x`` plot option.
        y: Forwarded unchanged as the ``y`` plot option.
        **kwargs: Any further plot options, forwarded unchanged.

    Returns:
        Whatever calling the accessor with ``kind="line"`` returns.
    """
    return self(kind="line", x=x, y=y, **kwargs)

def area(
self,
Expand All @@ -56,14 +70,7 @@ def area(
stacked: bool = True,
**kwargs,
):
return bfplt.plot(
self._parent.copy(),
kind="area",
x=x,
y=y,
stacked=stacked,
**kwargs,
)
return self(kind="area", x=x, y=y, stacked=stacked, **kwargs)

def scatter(
self,
Expand All @@ -73,12 +80,4 @@ def scatter(
c: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None,
**kwargs,
):
return bfplt.plot(
self._parent.copy(),
kind="scatter",
x=x,
y=y,
s=s,
c=c,
**kwargs,
)
return self(kind="scatter", x=x, y=y, s=s, c=c, **kwargs)
28 changes: 28 additions & 0 deletions tests/system/small/operations/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,31 @@ def test_sampling_plot_args_random_state():
msg = "numpy array are different"
with pytest.raises(AssertionError, match=msg):
tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_2.lines[0].get_data()[1])


@pytest.mark.parametrize(
    ("kind", "col_names", "kwargs"),
    [
        pytest.param("hist", ["int64_col", "int64_too"], {}),
        pytest.param("line", ["int64_col", "int64_too"], {}),
        pytest.param("area", ["int64_col", "int64_too"], {"stacked": False}),
        pytest.param(
            "scatter", ["int64_col", "int64_too"], {"x": "int64_col", "y": "int64_too"}
        ),
        # The two cases below are expected to raise. strict=True turns an
        # unexpected pass into a failure, so a regression that stops raising
        # cannot slip through as XPASS.
        pytest.param(
            "scatter",
            ["int64_col"],
            {},
            marks=pytest.mark.xfail(raises=ValueError, strict=True),
        ),
        pytest.param(
            "uknown",
            ["int64_col", "int64_too"],
            {},
            marks=pytest.mark.xfail(raises=NotImplementedError, strict=True),
        ),
    ],
)
def test_plot_call(scalars_dfs, kind, col_names, kwargs):
    """Call ``.plot(kind=...)`` on a column selection for each parametrized case.

    The first four cases must complete without raising. The last two must
    raise the exception named in their ``xfail`` mark.
    """
    scalars_df, _ = scalars_dfs
    scalars_df[col_names].plot(kind=kind, **kwargs)
38 changes: 37 additions & 1 deletion third_party/bigframes_vendored/pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,43 @@


class PlotAccessor:
"""Make plots of Series or DataFrame with the `matplotlib` backend."""
"""
Make plots of Series or DataFrame with the `matplotlib` backend.
**Examples:**
For Series:
>>> import bigframes.pandas as bpd
>>> ser = bpd.Series([1, 2, 3, 3])
>>> plot = ser.plot(kind='hist', title="My plot")
For DataFrame:
>>> df = bpd.DataFrame({'length': [1.5, 0.5, 1.2, 0.9, 3],
... 'width': [0.7, 0.2, 0.15, 0.2, 1.1]},
... index=['pig', 'rabbit', 'duck', 'chicken', 'horse'])
>>> plot = df.plot(title="DataFrame Plot")
Args:
data (Series or DataFrame):
The object for which the method is called.
kind (str):
The kind of plot to produce:
- 'line' : line plot (default)
- 'hist' : histogram
- 'area' : area plot
- 'scatter' : scatter plot (DataFrame only)
**kwargs:
Options to pass to `pandas.DataFrame.plot` method. See pandas
documentation online for more on these arguments.
Returns:
matplotlib.axes.Axes or np.ndarray of them:
An ndarray is returned with one :class:`matplotlib.axes.Axes`
per column when ``subplots=True``.
"""

def hist(
self, by: typing.Optional[typing.Sequence[str]] = None, bins: int = 10, **kwargs
Expand Down

0 comments on commit 1c3e668

Please sign in to comment.