diff --git a/bigframes/operations/_matplotlib/__init__.py b/bigframes/operations/_matplotlib/__init__.py index f8770a9ef8..02aca8cf5d 100644 --- a/bigframes/operations/_matplotlib/__init__.py +++ b/bigframes/operations/_matplotlib/__init__.py @@ -17,6 +17,9 @@ PLOT_CLASSES: dict[str, type[core.MPLPlot]] = { "hist": hist.HistPlot, + "line": core.LinePlot, + "area": core.AreaPlot, + "scatter": core.ScatterPlot, } diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 4b15d6f4dd..5c9d771f61 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -13,6 +13,7 @@ # limitations under the License. import abc +import typing import matplotlib.pyplot as plt @@ -28,3 +29,44 @@ def draw(self) -> None: @property def result(self): return self.axes + + +class SamplingPlot(MPLPlot): + @abc.abstractproperty + def _kind(self): + pass + + def __init__(self, data, **kwargs) -> None: + self.kwargs = kwargs + self.data = self._compute_plot_data(data) + + def generate(self) -> None: + self.axes = self.data.plot(kind=self._kind, **self.kwargs) + + def _compute_plot_data(self, data): + # TODO: Cache the sampling data in the PlotAccessor. + sampling_n = self.kwargs.pop("sampling_n", 100) + sampling_random_state = self.kwargs.pop("sampling_random_state", 0) + return ( + data.sample(n=sampling_n, random_state=sampling_random_state) + .to_pandas() + .sort_index() + ) + + +class LinePlot(SamplingPlot): + @property + def _kind(self) -> typing.Literal["line"]: + return "line" + + +class AreaPlot(SamplingPlot): + @property + def _kind(self) -> typing.Literal["area"]: + return "area" + + +class ScatterPlot(SamplingPlot): + @property + def _kind(self) -> typing.Literal["scatter"]: + return "scatter" diff --git a/bigframes/operations/plotting.py b/bigframes/operations/plotting.py index d19485e65e..cc9f71e5d1 100644 --- a/bigframes/operations/plotting.py +++ b/bigframes/operations/plotting.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence +import typing import bigframes_vendored.pandas.plotting._core as vendordt @@ -20,16 +20,65 @@ import bigframes.operations._matplotlib as bfplt -class PlotAccessor: +class PlotAccessor(vendordt.PlotAccessor): __doc__ = vendordt.PlotAccessor.__doc__ def __init__(self, data) -> None: self._parent = data - def hist(self, by: Optional[Sequence[str]] = None, bins: int = 10, **kwargs): + def hist( + self, by: typing.Optional[typing.Sequence[str]] = None, bins: int = 10, **kwargs + ): if kwargs.pop("backend", None) is not None: raise NotImplementedError( f"Only support matplotlib backend for now. {constants.FEEDBACK_LINK}" ) - # Calls matplotlib backend to plot the data. return bfplt.plot(self._parent.copy(), kind="hist", by=by, bins=bins, **kwargs) + + def line( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + **kwargs, + ): + return bfplt.plot( + self._parent.copy(), + kind="line", + x=x, + y=y, + **kwargs, + ) + + def area( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + stacked: bool = True, + **kwargs, + ): + return bfplt.plot( + self._parent.copy(), + kind="area", + x=x, + y=y, + stacked=stacked, + **kwargs, + ) + + def scatter( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + s: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None, + c: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None, + **kwargs, + ): + return bfplt.plot( + self._parent.copy(), + kind="scatter", + x=x, + y=y, + s=s, + c=c, + **kwargs, + ) diff --git a/tests/system/small/operations/test_plot.py b/tests/system/small/operations/test_plotting.py similarity index 69% rename from tests/system/small/operations/test_plot.py rename to tests/system/small/operations/test_plotting.py index 44f31ec071..ce320b6f57 100644 --- a/tests/system/small/operations/test_plot.py +++ b/tests/system/small/operations/test_plotting.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pandas._testing as tm import pytest +import bigframes.pandas as bpd + def _check_legend_labels(ax, labels): """ @@ -166,3 +169,67 @@ def test_hist_kwargs_ticks_props(scalars_dfs): for i in range(len(pd_xlables)): tm.assert_almost_equal(ylabels[i].get_fontsize(), pd_ylables[i].get_fontsize()) tm.assert_almost_equal(ylabels[i].get_rotation(), pd_ylables[i].get_rotation()) + + +def test_line(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_names = ["int64_col", "float64_col", "int64_too", "bool_col"] + ax = scalars_df[col_names].plot.line() + pd_ax = scalars_pandas_df[col_names].plot.line() + tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks()) + tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks()) + for line, pd_line in zip(ax.lines, pd_ax.lines): + # Compare y coordinates between the lines + tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1]) + + +def test_area(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_names = ["int64_col", "float64_col", "int64_too"] + ax = scalars_df[col_names].plot.area(stacked=False) + pd_ax = scalars_pandas_df[col_names].plot.area(stacked=False) + tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks()) + tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks()) + for line, pd_line in zip(ax.lines, pd_ax.lines): + # Compare y coordinates between the lines + tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1]) + + +def test_scatter(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_names = ["int64_col", "float64_col", "int64_too", "bool_col"] + ax = scalars_df[col_names].plot.scatter(x="int64_col", y="float64_col") + pd_ax = scalars_pandas_df[col_names].plot.scatter(x="int64_col", y="float64_col") + tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks()) + tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks()) + tm.assert_almost_equal( + ax.collections[0].get_sizes(), pd_ax.collections[0].get_sizes() + ) + + +def test_sampling_plot_args_n(): + df = bpd.DataFrame(np.arange(1000), columns=["one"]) + ax = df.plot.line() + assert len(ax.lines) == 1 + # Default sampling_n is 100 + assert len(ax.lines[0].get_data()[1]) == 100 + + ax = df.plot.line(sampling_n=2) + assert len(ax.lines) == 1 + assert len(ax.lines[0].get_data()[1]) == 2 + + +def test_sampling_plot_args_random_state(): + df = bpd.DataFrame(np.arange(1000), columns=["one"]) + ax_0 = df.plot.line() + ax_1 = df.plot.line() + ax_2 = df.plot.line(sampling_random_state=100) + ax_3 = df.plot.line(sampling_random_state=100) + + # Setting a fixed sampling_random_state guarantees reproducible plotted sampling. + tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_1.lines[0].get_data()[1]) + tm.assert_almost_equal(ax_2.lines[0].get_data()[1], ax_3.lines[0].get_data()[1]) + + msg = "numpy array are different" + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_2.lines[0].get_data()[1]) diff --git a/third_party/bigframes_vendored/pandas/plotting/_core.py b/third_party/bigframes_vendored/pandas/plotting/_core.py index d0425737ee..2b0f077695 100644 --- a/third_party/bigframes_vendored/pandas/plotting/_core.py +++ b/third_party/bigframes_vendored/pandas/plotting/_core.py @@ -1,14 +1,14 @@ -from typing import Optional, Sequence +import typing from bigframes import constants class PlotAccessor: - """ - Make plots of Series or DataFrame with the `matplotlib` backend. - """ + """Make plots of Series or DataFrame with the `matplotlib` backend.""" - def hist(self, by: Optional[Sequence[str]] = None, bins: int = 10, **kwargs): + def hist( + self, by: typing.Optional[typing.Sequence[str]] = None, bins: int = 10, **kwargs + ): """ Draw one histogram of the DataFrame’s columns. @@ -17,32 +17,237 @@ def hist(self, by: Optional[Sequence[str]] = None, bins: int = 10, **kwargs): into bins and draws all bins in one :class:`matplotlib.axes.Axes`. This is useful when the DataFrame's Series are in a similar scale. - Parameters - ---------- - by : str or sequence, optional - Column in the DataFrame to group by. It is not supported yet. - bins : int, default 10 - Number of histogram bins to be used. - **kwargs - Additional keyword arguments are documented in - :meth:`DataFrame.plot`. - - Returns - ------- - class:`matplotlib.AxesSubplot` - Return a histogram plot. - - Examples - -------- - For Series: - - .. plot:: - :context: close-figs + **Examples:** >>> import bigframes.pandas as bpd >>> import numpy as np >>> df = bpd.DataFrame(np.random.randint(1, 7, 6000), columns=['one']) >>> df['two'] = np.random.randint(1, 7, 6000) + np.random.randint(1, 7, 6000) >>> ax = df.plot.hist(bins=12, alpha=0.5) + + Args: + by (str or sequence, optional): + Column in the DataFrame to group by. It is not supported yet. + bins (int, default 10): + Number of histogram bins to be used. + **kwargs: + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns: + class:`matplotlib.AxesSubplot`: A histogram plot. + + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + + def line( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + **kwargs, + ): + """ + Plot Series or DataFrame as lines. This function is useful to plot lines + using DataFrame's values as coordinates. + + This function calls `pandas.plot` to generate a plot with a random sample + of items. For consistent results, the random sampling is reproducible. + Use the `sampling_random_state` parameter to modify the sampling seed. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> df = bpd.DataFrame( + ... { + ... 'one': [1, 2, 3, 4], + ... 'three': [3, 6, 9, 12], + ... 'reverse_ten': [40, 30, 20, 10], + ... } + ... ) + >>> ax = df.plot.line(x='one') + + Args: + x (label or position, optional): + Allows plotting of one column versus another. If not specified, + the index of the DataFrame is used. + y (label or position, optional): + Allows plotting of one column versus another. If not specified, + all numerical columns are used. + color (str, array-like, or dict, optional): + The color for each of the DataFrame's columns. Possible values are: + + - A single color string referred to by name, RGB or RGBA code, + for instance 'red' or '#a98d19'. + + - A sequence of color strings referred to by name, RGB or RGBA + code, which will be used for each column recursively. For + instance ['green','yellow'] each column's %(kind)s will be filled in + green or yellow, alternatively. If there is only a single column to + be plotted, then only the first color from the color list will be + used. + + - A dict of the form {column name : color}, so that each column will be + colored accordingly. For example, if your columns are called `a` and + `b`, then passing {'a': 'green', 'b': 'red'} will color %(kind)ss for + column `a` in green and %(kind)ss for column `b` in red. + sampling_n (int, default 100): + Number of random items for plotting. + sampling_random_state (int, default 0): + Seed for random number generator. + **kwargs: + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns: + matplotlib.axes.Axes or np.ndarray of them: + An ndarray is returned with one :class:`matplotlib.axes.Axes` + per column when ``subplots=True``. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + + def area( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + stacked: bool = True, + **kwargs, + ): + """ + Draw a stacked area plot. An area plot displays quantitative data visually. + + This function calls `pandas.plot` to generate a plot with a random sample + of items. For consistent results, the random sampling is reproducible. + Use the `sampling_random_state` parameter to modify the sampling seed. + + **Examples:** + + Draw an area plot based on basic business metrics: + + >>> import bigframes.pandas as bpd + >>> df = bpd.DataFrame( + ... { + ... 'sales': [3, 2, 3, 9, 10, 6], + ... 'signups': [5, 5, 6, 12, 14, 13], + ... 'visits': [20, 42, 28, 62, 81, 50], + ... }, + ... index=["01-31", "02-28", "03-31", "04-30", "05-31", "06-30"] + ... ) + >>> ax = df.plot.area() + + Area plots are stacked by default. To produce an unstacked plot, + pass ``stacked=False``: + + >>> ax = df.plot.area(stacked=False) + + Draw an area plot for a single column: + + >>> ax = df.plot.area(y='sales') + + Draw with a different `x`: + + >>> df = bpd.DataFrame({ + ... 'sales': [3, 2, 3], + ... 'visits': [20, 42, 28], + ... 'day': [1, 2, 3], + ... }) + >>> ax = df.plot.area(x='day') + + Args: + x (label or position, optional): + Coordinates for the X axis. By default uses the index. + y (label or position, optional): + Column to plot. By default uses all columns. + stacked (bool, default True): + Area plots are stacked by default. Set to False to create a + unstacked plot. + sampling_n (int, default 100): + Number of random items for plotting. + sampling_random_state (int, default 0): + Seed for random number generator. + **kwargs: + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns: + matplotlib.axes.Axes or numpy.ndarray: + Area plot, or array of area plots if subplots is True. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + + def scatter( + self, + x: typing.Optional[typing.Hashable] = None, + y: typing.Optional[typing.Hashable] = None, + s: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None, + c: typing.Union[typing.Hashable, typing.Sequence[typing.Hashable]] = None, + **kwargs, + ): + """ + Create a scatter plot with varying marker point size and color. + + This function calls `pandas.plot` to generate a plot with a random sample + of items. For consistent results, the random sampling is reproducible. + Use the `sampling_random_state` parameter to modify the sampling seed. + + **Examples:** + + Let's see how to draw a scatter plot using coordinates from the values + in a DataFrame's columns. + + >>> import bigframes.pandas as bpd + >>> df = bpd.DataFrame([[5.1, 3.5, 0], [4.9, 3.0, 0], [7.0, 3.2, 1], + ... [6.4, 3.2, 1], [5.9, 3.0, 2]], + ... columns=['length', 'width', 'species']) + >>> ax1 = df.plot.scatter(x='length', + ... y='width', + ... c='DarkBlue') + + And now with the color determined by a column as well. + + >>> ax2 = df.plot.scatter(x='length', + ... y='width', + ... c='species', + ... colormap='viridis') + + Args: + x (int or str): + The column name or column position to be used as horizontal + coordinates for each point. + y (int or str): + The column name or column position to be used as vertical + coordinates for each point. + s (str, scalar or array-like, optional): + The size of each point. Possible values are: + + - A string with the name of the column to be used for marker's size. + - A single scalar so all points have the same size. + - A sequence of scalars, which will be used for each point's size + recursively. For instance, when passing [2,14] all points size + will be either 2 or 14, alternatively. + + c (str, int or array-like, optional): + The color of each point. Possible values are: + + - A single color string referred to by name, RGB or RGBA code, + for instance 'red' or '#a98d19'. + - A sequence of color strings referred to by name, RGB or RGBA + code, which will be used for each point's color recursively. For + instance ['green','yellow'] all points will be filled in green or + yellow, alternatively. + - A column name or position whose values will be used to color the + marker points according to a colormap. + + sampling_n (int, default 100): + Number of random items for plotting. + sampling_random_state (int, default 0): + Seed for random number generator. + **kwargs: + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns: + matplotlib.axes.Axes or np.ndarray of them: + An ndarray is returned with one :class:`matplotlib.axes.Axes` + per column when ``subplots=True``. """ raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)