Skip to content

Commit

Permalink
fix: plot.scatter c arguments functionalities
Browse files Browse the repository at this point in the history
  • Loading branch information
chelsea-lin committed Mar 21, 2024
1 parent 429a4a5 commit 9bc60dd
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 3 deletions.
67 changes: 64 additions & 3 deletions bigframes/operations/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@

import abc
import typing
import uuid

import matplotlib.pyplot as plt
import pandas as pd

import bigframes.dtypes as dtypes


class MPLPlot(abc.ABC):
Expand All @@ -38,12 +42,13 @@ def _kind(self):

def __init__(self, data, **kwargs) -> None:
self.kwargs = kwargs
self.data = self._compute_plot_data(data)
self.data = data

def generate(self) -> None:
self.axes = self.data.plot(kind=self._kind, **self.kwargs)
plot_data = self._compute_plot_data()
self.axes = plot_data.plot(kind=self._kind, **self.kwargs)

def _compute_plot_data(self, data):
def _compute_sample_data(self, data):
# TODO: Cache the sampling data in the PlotAccessor.
sampling_n = self.kwargs.pop("sampling_n", 100)
sampling_random_state = self.kwargs.pop("sampling_random_state", 0)
Expand All @@ -53,6 +58,9 @@ def _compute_plot_data(self, data):
sort=False,
).to_pandas()

def _compute_plot_data(self):
    """Build the data that will be plotted.

    Delegates to ``_compute_sample_data`` on the stored ``self.data`` and
    returns whatever that call produces.
    """
    sampled = self._compute_sample_data(self.data)
    return sampled


class LinePlot(SamplingPlot):
@property
Expand All @@ -70,3 +78,56 @@ class ScatterPlot(SamplingPlot):
@property
def _kind(self) -> typing.Literal["scatter"]:
    """Plot kind for this class; it is forwarded as ``kind=`` when plotting."""
    kind: typing.Literal["scatter"] = "scatter"
    return kind

def __init__(self, data, **kwargs) -> None:
    """Initialize the plot and validate a sequence-valued ``c`` argument.

    Args:
        data: The data to plot; passed to the parent initializer.
        **kwargs: Plot options. ``c`` is inspected here: when it is a
            non-string iterable its length must equal the number of rows
            in ``self.data``.

    Raises:
        ValueError: If ``c`` is a non-string iterable whose length differs
            from ``self.data.shape[0]``.
    """
    super().__init__(data, **kwargs)

    colors = self.kwargs.get("c", None)
    if not self._is_sequence_arg(colors):
        # Scalars, column names and a missing ``c`` need no length check.
        return

    n_colors = len(colors)
    n_rows = self.data.shape[0]
    if n_colors != n_rows:
        raise ValueError(
            f"'c' argument has {n_colors} elements, which is "
            + f"inconsistent with 'x' and 'y' with size {n_rows}"
        )

def _compute_plot_data(self):
    """Sample the data for plotting while keeping a sequence ``c`` aligned.

    When ``c`` is a non-string iterable, it is attached to a copy of
    ``self.data`` as a temporary column so that it goes through the same
    sampling as the rows. After sampling, the sampled values replace
    ``self.kwargs["c"]`` and the temporary column is dropped again.

    Returns:
        The result of ``_compute_sample_data`` without the temporary column.
    """
    data = self.data.copy()

    c = self.kwargs.get("c", None)
    c_id = None
    if self._is_sequence_arg(c):
        # Carry the per-point values through sampling as a temporary column.
        c_id = self._generate_new_column_name(data)
        data[c_id] = c

    sample = self._compute_sample_data(data)

    # Works around a pandas bug:
    # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a
    if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
        sample[c] = sample[c].astype("object")

    if c_id is not None:
        # Hand the sampled values back as ``c`` and remove the helper column
        # so it is not part of the returned data.
        self.kwargs["c"] = sample[c_id]
        sample = sample.drop(columns=[c_id])

    return sample

def _is_sequence_arg(self, arg):
    """Return True when ``arg`` is an iterable other than a string.

    ``None`` and ``str`` values are rejected up front; everything else is
    accepted only if it is an ``Iterable``.
    """
    if arg is None or isinstance(arg, str):
        return False
    return isinstance(arg, typing.Iterable)

def _is_column_name(self, arg, data):
    """Return True when ``arg`` is a label found in ``data.columns``.

    Args:
        arg: Candidate column label; may be any object.
        data: Object exposing a ``columns`` attribute that supports ``in``.

    Returns:
        False for ``None`` and for unhashable values; otherwise the result
        of ``arg in data.columns``.
    """
    # Unhashable values (e.g. lists) cannot be column labels; rejecting them
    # first keeps the membership test from being attempted on them. The
    # public ``pd.api.types`` entry point is used instead of pandas internals.
    if arg is None or not pd.api.types.is_hashable(arg):
        return False
    return arg in data.columns

def _generate_new_column_name(self, data):
    """Return a ``plot_temp_``-prefixed label that is not in ``data.columns``.

    The suffix is the first eight characters of a fresh ``uuid4`` string;
    new candidates are drawn until one does not clash with an existing column.
    """
    while True:
        candidate = f"plot_temp_{str(uuid.uuid4())[:8]}"
        if candidate not in data.columns:
            return candidate
61 changes: 61 additions & 0 deletions tests/system/small/operations/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,67 @@ def test_scatter(scalars_dfs):
)


@pytest.mark.parametrize(
    "c",
    [
        pytest.param("red", id="red"),
        pytest.param("c", id="int_column"),
        pytest.param("species", id="color_column"),
        pytest.param(["red", "green", "blue"], id="color_sequence"),
        pytest.param([3.4, 5.3, 2.0], id="number_sequence"),
        pytest.param(
            [3.4, 5.3],
            id="length_mismatches_sequence",
            marks=pytest.mark.xfail(
                raises=ValueError,
            ),
        ),
    ],
)
def test_scatter_args_c(c):
    """Scatter face colors must match pandas for each supported form of ``c``."""
    columns = {
        "a": [1, 2, 3],
        "b": [1, 2, 3],
        "c": [1, 2, 3],
        "species": ["r", "g", "b"],
    }
    bf_frame = bpd.DataFrame(columns)
    pandas_frame = pd.DataFrame(columns)

    ax = bf_frame.plot.scatter(x="a", y="b", c=c)
    pd_ax = pandas_frame.plot.scatter(x="a", y="b", c=c)

    # Compare the face colors of the first collection on each axes, row by row.
    actual_colors = ax.collections[0].get_facecolor()
    expected_colors = pd_ax.collections[0].get_facecolor()
    assert len(actual_colors) == len(expected_colors)
    for actual_rgba, expected_rgba in zip(actual_colors, expected_colors):
        tm.assert_numpy_array_equal(actual_rgba, expected_rgba)


def test_scatter_args_c_sampling():
    """Scatter face colors with a sequence ``c`` and ``sampling_n=3`` match pandas."""
    # NOTE(review): the ``plot_temp_`` column names presumably probe clashes
    # with the temporary color column used during sampling — confirm.
    columns = {
        "plot_temp_0": [1, 2, 3, 4, 5],
        "plot_temp_1": [5, 4, 3, 2, 1],
    }
    c = ["red", "green", "blue", "orange", "black"]

    bf_frame = bpd.DataFrame(columns)
    pandas_frame = pd.DataFrame(columns)

    ax = bf_frame.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c, sampling_n=3)
    pd_ax = pandas_frame.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c)

    # Compare the face colors of the first collection on each axes, row by row.
    actual_colors = ax.collections[0].get_facecolor()
    expected_colors = pd_ax.collections[0].get_facecolor()
    assert len(actual_colors) == len(expected_colors)
    for actual_rgba, expected_rgba in zip(actual_colors, expected_colors):
        tm.assert_numpy_array_equal(actual_rgba, expected_rgba)


def test_sampling_plot_args_n():
df = bpd.DataFrame(np.arange(1000), columns=["one"])
ax = df.plot.line()
Expand Down

0 comments on commit 9bc60dd

Please sign in to comment.