diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 7cbeb3df4f..6f40ab5516 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -14,8 +14,12 @@ import abc import typing +import uuid import matplotlib.pyplot as plt +import pandas as pd + +import bigframes.dtypes as dtypes class MPLPlot(abc.ABC): @@ -38,12 +42,13 @@ def _kind(self): def __init__(self, data, **kwargs) -> None: self.kwargs = kwargs - self.data = self._compute_plot_data(data) + self.data = data def generate(self) -> None: - self.axes = self.data.plot(kind=self._kind, **self.kwargs) + plot_data = self._compute_plot_data() + self.axes = plot_data.plot(kind=self._kind, **self.kwargs) - def _compute_plot_data(self, data): + def _compute_sample_data(self, data): # TODO: Cache the sampling data in the PlotAccessor. sampling_n = self.kwargs.pop("sampling_n", 100) sampling_random_state = self.kwargs.pop("sampling_random_state", 0) @@ -53,6 +58,9 @@ def _compute_plot_data(self, data): sort=False, ).to_pandas() + def _compute_plot_data(self): + return self._compute_sample_data(self.data) + class LinePlot(SamplingPlot): @property @@ -70,3 +78,56 @@ class ScatterPlot(SamplingPlot): @property def _kind(self) -> typing.Literal["scatter"]: return "scatter" + + def __init__(self, data, **kwargs) -> None: + super().__init__(data, **kwargs) + + c = self.kwargs.get("c", None) + if self._is_sequence_arg(c) and len(c) != self.data.shape[0]: + raise ValueError( + f"'c' argument has {len(c)} elements, which is " + + f"inconsistent with 'x' and 'y' with size {self.data.shape[0]}" + ) + + def _compute_plot_data(self): + data = self.data.copy() + + c = self.kwargs.get("c", None) + c_id = None + if self._is_sequence_arg(c): + c_id = self._generate_new_column_name(data) + print(c_id) + data[c_id] = c + + sample = self._compute_sample_data(data) + + # Works around a pandas bug: + # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a + if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE: + sample[c] = sample[c].astype("object") + + if c_id is not None: + self.kwargs["c"] = sample[c_id] + sample = sample.drop(columns=[c_id]) + + return sample + + def _is_sequence_arg(self, arg): + return ( + arg is not None + and not isinstance(arg, str) + and isinstance(arg, typing.Iterable) + ) + + def _is_column_name(self, arg, data): + return ( + arg is not None + and pd.core.dtypes.common.is_hashable(arg) + and arg in data.columns + ) + + def _generate_new_column_name(self, data): + col_name = None + while col_name is None or col_name in data.columns: + col_name = f"plot_temp_{str(uuid.uuid4())[:8]}" + return col_name diff --git a/tests/system/small/operations/test_plotting.py b/tests/system/small/operations/test_plotting.py index 47491cdada..117bcd2f8e 100644 --- a/tests/system/small/operations/test_plotting.py +++ b/tests/system/small/operations/test_plotting.py @@ -208,6 +208,67 @@ def test_scatter(scalars_dfs): ) +@pytest.mark.parametrize( + ("c"), + [ + pytest.param("red", id="red"), + pytest.param("c", id="int_column"), + pytest.param("species", id="color_column"), + pytest.param(["red", "green", "blue"], id="color_sequence"), + pytest.param([3.4, 5.3, 2.0], id="number_sequence"), + pytest.param( + [3.4, 5.3], + id="length_mismatches_sequence", + marks=pytest.mark.xfail( + raises=ValueError, + ), + ), + ], +) +def test_scatter_args_c(c): + data = { + "a": [1, 2, 3], + "b": [1, 2, 3], + "c": [1, 2, 3], + "species": ["r", "g", "b"], + } + df = bpd.DataFrame(data) + pd_df = pd.DataFrame(data) + + ax = df.plot.scatter(x="a", y="b", c=c) + pd_ax = pd_df.plot.scatter(x="a", y="b", c=c) + assert len(ax.collections[0].get_facecolor()) == len( + pd_ax.collections[0].get_facecolor() + ) + for idx in range(len(ax.collections[0].get_facecolor())): + tm.assert_numpy_array_equal( + ax.collections[0].get_facecolor()[idx], + pd_ax.collections[0].get_facecolor()[idx], + ) + + +def test_scatter_args_c_sampling(): + data = { + "plot_temp_0": [1, 2, 3, 4, 5], + "plot_temp_1": [5, 4, 3, 2, 1], + } + c = ["red", "green", "blue", "orange", "black"] + + df = bpd.DataFrame(data) + pd_df = pd.DataFrame(data) + + ax = df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c, sampling_n=3) + pd_ax = pd_df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c) + assert len(ax.collections[0].get_facecolor()) == len( + pd_ax.collections[0].get_facecolor() + ) + for idx in range(len(ax.collections[0].get_facecolor())): + tm.assert_numpy_array_equal( + ax.collections[0].get_facecolor()[idx], + pd_ax.collections[0].get_facecolor()[idx], + ) + + def test_sampling_plot_args_n(): df = bpd.DataFrame(np.arange(1000), columns=["one"]) ax = df.plot.line()