Skip to content

Commit

Permalink
fix: sampling plot cannot preserve ordering if index is not ordered (#…
Browse files Browse the repository at this point in the history
…475)

* fix: sampling plot cannot preserve ordering if index is not ordered

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* change sort type

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
chelsea-lin and gcf-owl-bot[bot] authored Mar 20, 2024
1 parent 43d0864 commit a5345fe
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 13 deletions.
19 changes: 15 additions & 4 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import itertools
import random
import typing
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple
from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple
import warnings

import google.cloud.bigquery as bigquery
Expand Down Expand Up @@ -555,7 +555,7 @@ def _downsample(
block = self._split(
fracs=(fraction,),
random_state=random_state,
preserve_order=True,
sort=False,
)[0]
return block
else:
Expand All @@ -571,7 +571,7 @@ def _split(
fracs: Iterable[float] = (),
*,
random_state: Optional[int] = None,
preserve_order: Optional[bool] = False,
sort: Optional[bool | Literal["random"]] = "random",
) -> List[Block]:
"""Internal function to support splitting Block to multiple parts along index axis.
Expand Down Expand Up @@ -623,7 +623,18 @@ def _split(
typing.cast(Block, block.slice(start=lower, stop=upper))
for lower, upper in intervals
]
if preserve_order:

if sort is True:
sliced_blocks = [
sliced_block.order_by(
[
ordering.OrderingColumnReference(idx_col)
for idx_col in sliced_block.index_columns
]
)
for sliced_block in sliced_blocks
]
elif sort is False:
sliced_blocks = [
sliced_block.order_by([ordering.OrderingColumnReference(ordering_col)])
for sliced_block in sliced_blocks
Expand Down
5 changes: 4 additions & 1 deletion bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2504,14 +2504,17 @@ def sample(
frac: Optional[float] = None,
*,
random_state: Optional[int] = None,
sort: Optional[bool | Literal["random"]] = "random",
) -> DataFrame:
if n is not None and frac is not None:
raise ValueError("Only one of 'n' or 'frac' parameter can be specified.")

ns = (n,) if n is not None else ()
fracs = (frac,) if frac is not None else ()
return DataFrame(
self._block._split(ns=ns, fracs=fracs, random_state=random_state)[0]
self._block._split(
ns=ns, fracs=fracs, random_state=random_state, sort=sort
)[0]
)

def _split(
Expand Down
10 changes: 5 additions & 5 deletions bigframes/operations/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def _compute_plot_data(self, data):
# TODO: Cache the sampling data in the PlotAccessor.
sampling_n = self.kwargs.pop("sampling_n", 100)
sampling_random_state = self.kwargs.pop("sampling_random_state", 0)
return (
data.sample(n=sampling_n, random_state=sampling_random_state)
.to_pandas()
.sort_index()
)
return data.sample(
n=sampling_n,
random_state=sampling_random_state,
sort=False,
).to_pandas()


class LinePlot(SamplingPlot):
Expand Down
7 changes: 5 additions & 2 deletions bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import textwrap
import typing
from typing import Any, Mapping, Optional, Tuple, Union
from typing import Any, Literal, Mapping, Optional, Tuple, Union

import bigframes_vendored.pandas.core.series as vendored_pandas_series
import google.cloud.bigquery as bigquery
Expand Down Expand Up @@ -1535,14 +1535,17 @@ def sample(
frac: Optional[float] = None,
*,
random_state: Optional[int] = None,
sort: Optional[bool | Literal["random"]] = "random",
) -> Series:
if n is not None and frac is not None:
raise ValueError("Only one of 'n' or 'frac' parameter can be specified.")

ns = (n,) if n is not None else ()
fracs = (frac,) if frac is not None else ()
return Series(
self._block._split(ns=ns, fracs=fracs, random_state=random_state)[0]
self._block._split(
ns=ns, fracs=fracs, random_state=random_state, sort=sort
)[0]
)

def __array_ufunc__(
Expand Down
15 changes: 14 additions & 1 deletion tests/system/small/operations/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import pandas as pd
import pandas._testing as tm
import pytest

Expand Down Expand Up @@ -235,6 +236,18 @@ def test_sampling_plot_args_random_state():
tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_2.lines[0].get_data()[1])


def test_sampling_preserve_ordering():
df = bpd.DataFrame([0.0, 1.0, 2.0, 3.0, 4.0], index=[1, 3, 4, 2, 0])
pd_df = pd.DataFrame([0.0, 1.0, 2.0, 3.0, 4.0], index=[1, 3, 4, 2, 0])
ax = df.plot.line()
pd_ax = pd_df.plot.line()
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
for line, pd_line in zip(ax.lines, pd_ax.lines):
# Compare y coordinates between the lines
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])


@pytest.mark.parametrize(
("kind", "col_names", "kwargs"),
[
Expand All @@ -251,7 +264,7 @@ def test_sampling_plot_args_random_state():
marks=pytest.mark.xfail(raises=ValueError),
),
pytest.param(
"uknown",
"bar",
["int64_col", "int64_too"],
{},
marks=pytest.mark.xfail(raises=NotImplementedError),
Expand Down
22 changes: 22 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3049,6 +3049,28 @@ def test_sample_raises_value_error(scalars_dfs):
scalars_df.sample(frac=0.5, n=4)


def test_sample_args_sort(scalars_dfs):
scalars_df, _ = scalars_dfs
index = [4, 3, 2, 5, 1, 0]
scalars_df = scalars_df.iloc[index]

kwargs = {"frac": 1.0, "random_state": 333}

df = scalars_df.sample(**kwargs).to_pandas()
assert df.index.values != index
assert df.index.values != sorted(index)

df = scalars_df.sample(sort="random", **kwargs).to_pandas()
assert df.index.values != index
assert df.index.values != sorted(index)

df = scalars_df.sample(sort=True, **kwargs).to_pandas()
assert df.index.values == sorted(index)

df = scalars_df.sample(sort=False, **kwargs).to_pandas()
assert df.index.values == index


@pytest.mark.parametrize(
("axis",),
[
Expand Down
7 changes: 7 additions & 0 deletions third_party/bigframes_vendored/pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ def sample(
frac: Optional[float] = None,
*,
random_state: Optional[int] = None,
sort: Optional[bool | Literal["random"]] = "random",
):
"""Return a random sample of items from an axis of object.
Expand Down Expand Up @@ -530,6 +531,12 @@ def sample(
Fraction of axis items to return. Cannot be used with `n`.
random_state (Optional[int], default None):
Seed for random number generator.
sort (Optional[bool|Literal["random"]], default "random"):
- 'random' (default): No specific ordering will be applied after
sampling.
- 'True' : Index columns will determine the sample's order.
- 'False': The sample will retain the original object's order.
Returns:
A new object of same type as caller containing `n` items randomly
Expand Down

0 comments on commit a5345fe

Please sign in to comment.