diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 663e7a789f..ad5abb4bca 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -14,6 +14,12 @@ import abc import typing +import uuid + +import pandas as pd + +import bigframes.constants as constants +import bigframes.dtypes as dtypes DEFAULT_SAMPLING_N = 1000 DEFAULT_SAMPLING_STATE = 0 @@ -44,12 +50,13 @@ def _kind(self): def __init__(self, data, **kwargs) -> None: self.kwargs = kwargs - self.data = self._compute_plot_data(data) + self.data = data def generate(self) -> None: - self.axes = self.data.plot(kind=self._kind, **self.kwargs) + plot_data = self._compute_plot_data() + self.axes = plot_data.plot(kind=self._kind, **self.kwargs) - def _compute_plot_data(self, data): + def _compute_sample_data(self, data): # TODO: Cache the sampling data in the PlotAccessor. sampling_n = self.kwargs.pop("sampling_n", DEFAULT_SAMPLING_N) sampling_random_state = self.kwargs.pop( @@ -61,6 +68,9 @@ def _compute_plot_data(self, data): sort=False, ).to_pandas() + def _compute_plot_data(self): + return self._compute_sample_data(self.data) + class LinePlot(SamplingPlot): @property @@ -78,3 +88,45 @@ class ScatterPlot(SamplingPlot): @property def _kind(self) -> typing.Literal["scatter"]: return "scatter" + + def __init__(self, data, **kwargs) -> None: + super().__init__(data, **kwargs) + + c = self.kwargs.get("c", None) + if self._is_sequence_arg(c): + raise NotImplementedError( + f"Only support a single color string or a column name/posision. {constants.FEEDBACK_LINK}" + ) + + def _compute_plot_data(self): + sample = self._compute_sample_data(self.data) + + # Works around a pandas bug: + # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a + c = self.kwargs.get("c", None) + if pd.core.dtypes.common.is_integer(c): + c = self.data.columns[c] + if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE: + sample[c] = sample[c].astype("object") + + return sample + + def _is_sequence_arg(self, arg): + return ( + arg is not None + and not isinstance(arg, str) + and isinstance(arg, typing.Iterable) + ) + + def _is_column_name(self, arg, data): + return ( + arg is not None + and pd.core.dtypes.common.is_hashable(arg) + and arg in data.columns + ) + + def _generate_new_column_name(self, data): + col_name = None + while col_name is None or col_name in data.columns: + col_name = f"plot_temp_{str(uuid.uuid4())[:8]}" + return col_name diff --git a/tests/system/small/operations/test_plotting.py b/tests/system/small/operations/test_plotting.py index 5ca3382e2a..41ea7d4ebb 100644 --- a/tests/system/small/operations/test_plotting.py +++ b/tests/system/small/operations/test_plotting.py @@ -209,6 +209,37 @@ def test_scatter(scalars_dfs): ) +@pytest.mark.parametrize( + ("c"), + [ + pytest.param("red", id="red"), + pytest.param("c", id="int_column"), + pytest.param("species", id="color_column"), + pytest.param(3, id="column_index"), + ], +) +def test_scatter_args_c(c): + data = { + "a": [1, 2, 3], + "b": [1, 2, 3], + "c": [1, 2, 3], + "species": ["r", "g", "b"], + } + df = bpd.DataFrame(data) + pd_df = pd.DataFrame(data) + + ax = df.plot.scatter(x="a", y="b", c=c) + pd_ax = pd_df.plot.scatter(x="a", y="b", c=c) + assert len(ax.collections[0].get_facecolor()) == len( + pd_ax.collections[0].get_facecolor() + ) + for idx in range(len(ax.collections[0].get_facecolor())): + tm.assert_numpy_array_equal( + ax.collections[0].get_facecolor()[idx], + pd_ax.collections[0].get_facecolor()[idx], + ) + + def test_sampling_plot_args_n(): df = bpd.DataFrame(np.arange(bf_mpl.DEFAULT_SAMPLING_N * 10), columns=["one"]) ax = df.plot.line() diff --git a/third_party/bigframes_vendored/pandas/plotting/_core.py b/third_party/bigframes_vendored/pandas/plotting/_core.py index d901f41ef8..f8da9efdc0 100644 --- a/third_party/bigframes_vendored/pandas/plotting/_core.py +++ b/third_party/bigframes_vendored/pandas/plotting/_core.py @@ -266,10 +266,6 @@ def scatter( - A single color string referred to by name, RGB or RGBA code, for instance 'red' or '#a98d19'. - - A sequence of color strings referred to by name, RGB or RGBA - code, which will be used for each point's color recursively. For - instance ['green','yellow'] all points will be filled in green or - yellow, alternatively. - A column name or position whose values will be used to color the marker points according to a colormap.