Skip to content

Commit

Permalink
feat: (Series|DataFrame).plot.(line|area|scatter) (#431)
Browse files Browse the repository at this point in the history
Fixing internal bugs:
line: b/322177942
scatter: b/322178336
area: b/322178394
  • Loading branch information
chelsea-lin authored Mar 14, 2024
1 parent 7f3d41c commit 0772510
Show file tree
Hide file tree
Showing 5 changed files with 396 additions and 30 deletions.
3 changes: 3 additions & 0 deletions bigframes/operations/_matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

PLOT_CLASSES: dict[str, type[core.MPLPlot]] = {
"hist": hist.HistPlot,
"line": core.LinePlot,
"area": core.AreaPlot,
"scatter": core.ScatterPlot,
}


Expand Down
42 changes: 42 additions & 0 deletions bigframes/operations/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import abc
import typing

import matplotlib.pyplot as plt

Expand All @@ -28,3 +29,44 @@ def draw(self) -> None:
@property
def result(self):
return self.axes


class SamplingPlot(MPLPlot):
@abc.abstractproperty
def _kind(self):
pass

def __init__(self, data, **kwargs) -> None:
self.kwargs = kwargs
self.data = self._compute_plot_data(data)

def generate(self) -> None:
self.axes = self.data.plot(kind=self._kind, **self.kwargs)

def _compute_plot_data(self, data):
# TODO: Cache the sampling data in the PlotAccessor.
sampling_n = self.kwargs.pop("sampling_n", 100)
sampling_random_state = self.kwargs.pop("sampling_random_state", 0)
return (
data.sample(n=sampling_n, random_state=sampling_random_state)
.to_pandas()
.sort_index()
)


class LinePlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["line"]:
return "line"


class AreaPlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["area"]:
return "area"


class ScatterPlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["scatter"]:
return "scatter"
57 changes: 53 additions & 4 deletions bigframes/operations/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,73 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence
import typing

import bigframes_vendored.pandas.plotting._core as vendordt

import bigframes.constants as constants
import bigframes.operations._matplotlib as bfplt


class PlotAccessor:
class PlotAccessor(vendordt.PlotAccessor):
__doc__ = vendordt.PlotAccessor.__doc__

def __init__(self, data) -> None:
self._parent = data

def hist(self, by: Optional[Sequence[str]] = None, bins: int = 10, **kwargs):
def hist(
self, by: typing.Optional[typing.Sequence[str]] = None, bins: int = 10, **kwargs
):
if kwargs.pop("backend", None) is not None:
raise NotImplementedError(
f"Only support matplotlib backend for now. {constants.FEEDBACK_LINK}"
)
# Calls matplotlib backend to plot the data.
return bfplt.plot(self._parent.copy(), kind="hist", by=by, bins=bins, **kwargs)

def line(
self,
x: typing.Optional[typing.Hashable] = None,
y: typing.Optional[typing.Hashable] = None,
**kwargs,
):
return bfplt.plot(
self._parent.copy(),
kind="line",
x=x,
y=y,
**kwargs,
)

def area(
self,
x: typing.Optional[typing.Hashable] = None,
y: typing.Optional[typing.Hashable] = None,
stacked: bool = True,
**kwargs,
):
return bfplt.plot(
self._parent.copy(),
kind="area",
x=x,
y=y,
stacked=stacked,
**kwargs,
)

def scatter(
self,
x: typing.Optional[typing.Hashable] = None,
y: typing.Optional[typing.Hashable] = None,
s: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None,
c: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None,
**kwargs,
):
return bfplt.plot(
self._parent.copy(),
kind="scatter",
x=x,
y=y,
s=s,
c=c,
**kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas._testing as tm
import pytest

import bigframes.pandas as bpd


def _check_legend_labels(ax, labels):
"""
Expand Down Expand Up @@ -166,3 +169,67 @@ def test_hist_kwargs_ticks_props(scalars_dfs):
for i in range(len(pd_xlables)):
tm.assert_almost_equal(ylabels[i].get_fontsize(), pd_ylables[i].get_fontsize())
tm.assert_almost_equal(ylabels[i].get_rotation(), pd_ylables[i].get_rotation())


def test_line(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
col_names = ["int64_col", "float64_col", "int64_too", "bool_col"]
ax = scalars_df[col_names].plot.line()
pd_ax = scalars_pandas_df[col_names].plot.line()
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
for line, pd_line in zip(ax.lines, pd_ax.lines):
# Compare y coordinates between the lines
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])


def test_area(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
col_names = ["int64_col", "float64_col", "int64_too"]
ax = scalars_df[col_names].plot.area(stacked=False)
pd_ax = scalars_pandas_df[col_names].plot.area(stacked=False)
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
for line, pd_line in zip(ax.lines, pd_ax.lines):
# Compare y coordinates between the lines
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])


def test_scatter(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
col_names = ["int64_col", "float64_col", "int64_too", "bool_col"]
ax = scalars_df[col_names].plot.scatter(x="int64_col", y="float64_col")
pd_ax = scalars_pandas_df[col_names].plot.scatter(x="int64_col", y="float64_col")
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
tm.assert_almost_equal(
ax.collections[0].get_sizes(), pd_ax.collections[0].get_sizes()
)


def test_sampling_plot_args_n():
df = bpd.DataFrame(np.arange(1000), columns=["one"])
ax = df.plot.line()
assert len(ax.lines) == 1
# Default sampling_n is 100
assert len(ax.lines[0].get_data()[1]) == 100

ax = df.plot.line(sampling_n=2)
assert len(ax.lines) == 1
assert len(ax.lines[0].get_data()[1]) == 2


def test_sampling_plot_args_random_state():
df = bpd.DataFrame(np.arange(1000), columns=["one"])
ax_0 = df.plot.line()
ax_1 = df.plot.line()
ax_2 = df.plot.line(sampling_random_state=100)
ax_3 = df.plot.line(sampling_random_state=100)

# Setting a fixed sampling_random_state guarantees reproducible plotted sampling.
tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_1.lines[0].get_data()[1])
tm.assert_almost_equal(ax_2.lines[0].get_data()[1], ax_3.lines[0].get_data()[1])

msg = "numpy array are different"
with pytest.raises(AssertionError, match=msg):
tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_2.lines[0].get_data()[1])
Loading

0 comments on commit 0772510

Please sign in to comment.