Skip to content

Commit

Permalink
fix: plot.scatter c argument functionalities (#494)
Browse files Browse the repository at this point in the history
Fixes internal bug: b/330770901 🦕
  • Loading branch information
chelsea-lin authored Mar 24, 2024
1 parent 65c6f47 commit d6ee994
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 7 deletions.
58 changes: 55 additions & 3 deletions bigframes/operations/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@

import abc
import typing
import uuid

import pandas as pd

import bigframes.constants as constants
import bigframes.dtypes as dtypes

DEFAULT_SAMPLING_N = 1000
DEFAULT_SAMPLING_STATE = 0
Expand Down Expand Up @@ -44,12 +50,13 @@ def _kind(self):

def __init__(self, data, **kwargs) -> None:
self.kwargs = kwargs
self.data = self._compute_plot_data(data)
self.data = data

def generate(self) -> None:
    """Draw the plot and store the result on ``self.axes``.

    The data to draw comes from ``_compute_plot_data`` rather than from
    ``self.data`` directly, so subclasses can adjust it before plotting.
    """
    # Compute the plot data first: sampling pops its own options
    # (e.g. ``sampling_n``) out of ``self.kwargs`` before the remaining
    # kwargs are forwarded to ``plot``.
    plot_data = self._compute_plot_data()
    self.axes = plot_data.plot(kind=self._kind, **self.kwargs)

def _compute_plot_data(self, data):
def _compute_sample_data(self, data):
# TODO: Cache the sampling data in the PlotAccessor.
sampling_n = self.kwargs.pop("sampling_n", DEFAULT_SAMPLING_N)
sampling_random_state = self.kwargs.pop(
Expand All @@ -61,6 +68,9 @@ def _compute_plot_data(self, data):
sort=False,
).to_pandas()

def _compute_plot_data(self):
return self._compute_sample_data(self.data)


class LinePlot(SamplingPlot):
@property
Expand All @@ -78,3 +88,45 @@ class ScatterPlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["scatter"]:
return "scatter"

def __init__(self, data, **kwargs) -> None:
    """Initialize the scatter plot and reject unsupported ``c`` arguments.

    Args:
        data: Object to plot; forwarded to the parent initializer.
        **kwargs: Plotting options. ``c`` may be a single color string or a
            column name/position.

    Raises:
        NotImplementedError: If ``c`` is a non-string iterable (for example
            a list of color strings).
    """
    super().__init__(data, **kwargs)

    c = self.kwargs.get("c")
    if self._is_sequence_arg(c):
        raise NotImplementedError(
            "Only support a single color string or a column name/position. "
            f"{constants.FEEDBACK_LINK}"
        )

def _compute_plot_data(self):
    """Sample the data and make a string-typed color column plottable.

    Returns:
        The sampled data from ``_compute_sample_data``. When ``c`` refers to
        a column (by label or by position) whose dtype is
        ``dtypes.STRING_DTYPE``, that column is cast to ``object`` in the
        sample.
    """
    sample = self._compute_sample_data(self.data)

    # Works around a pandas bug:
    # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a
    c = self.kwargs.get("c")
    # Treat an integer as a position only when it is not already a column
    # label; otherwise frames with integer column labels would have the
    # workaround applied to the wrong column.
    if pd.api.types.is_integer(c) and not self._is_column_name(c, sample):
        c = self.data.columns[c]
    if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
        sample[c] = sample[c].astype("object")

    return sample

def _is_sequence_arg(self, arg):
return (
arg is not None
and not isinstance(arg, str)
and isinstance(arg, typing.Iterable)
)

def _is_column_name(self, arg, data):
return (
arg is not None
and pd.core.dtypes.common.is_hashable(arg)
and arg in data.columns
)

def _generate_new_column_name(self, data):
col_name = None
while col_name is None or col_name in data.columns:
col_name = f"plot_temp_{str(uuid.uuid4())[:8]}"
return col_name
31 changes: 31 additions & 0 deletions tests/system/small/operations/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,37 @@ def test_scatter(scalars_dfs):
)


@pytest.mark.parametrize(
    ("c"),
    [
        pytest.param("red", id="red"),
        pytest.param("c", id="int_column"),
        pytest.param("species", id="color_column"),
        pytest.param(3, id="column_index"),
    ],
)
def test_scatter_args_c(c):
    """Scatter face colors must match pandas for each supported form of ``c``."""
    data = {
        "a": [1, 2, 3],
        "b": [1, 2, 3],
        "c": [1, 2, 3],
        "species": ["r", "g", "b"],
    }
    bf_df = bpd.DataFrame(data)
    expected_df = pd.DataFrame(data)

    ax = bf_df.plot.scatter(x="a", y="b", c=c)
    pd_ax = expected_df.plot.scatter(x="a", y="b", c=c)

    actual_colors = ax.collections[0].get_facecolor()
    expected_colors = pd_ax.collections[0].get_facecolor()
    assert len(actual_colors) == len(expected_colors)
    # Compare point by point so a failure names the offending color row.
    for actual, expected in zip(actual_colors, expected_colors):
        tm.assert_numpy_array_equal(actual, expected)


def test_sampling_plot_args_n():
df = bpd.DataFrame(np.arange(bf_mpl.DEFAULT_SAMPLING_N * 10), columns=["one"])
ax = df.plot.line()
Expand Down
4 changes: 0 additions & 4 deletions third_party/bigframes_vendored/pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,6 @@ def scatter(
- A single color string referred to by name, RGB or RGBA code,
for instance 'red' or '#a98d19'.
- A sequence of color strings referred to by name, RGB or RGBA
code, which will be used for each point's color recursively. For
instance ['green','yellow'] all points will be filled in green or
yellow, alternatively.
- A column name or position whose values will be used to color the
marker points according to a colormap.
Expand Down

0 comments on commit d6ee994

Please sign in to comment.