Skip to content

Commit

Permalink
chore(pandas): fix implementation to handle new zero-argument modeling
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Oct 12, 2023
1 parent 0126eda commit b564fcb
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 28 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/dask/tests/execution/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ def test_batting_avg_change_in_games_per_year(players, players_df):


@pytest.mark.xfail(
raises=NotImplementedError,
reason="Grouped and order windows not supported yet",
raises=AssertionError,
reason="Dask doesn't support the `rank` method on SeriesGroupBy",
)
def test_batting_most_hits(players, players_df):
expr = players.mutate(
Expand Down
51 changes: 30 additions & 21 deletions ibis/backends/pandas/execution/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ def trim_window_result(data: pd.Series | pd.DataFrame, timecontext: TimeContext
return indexed_subset[name]


@execute_node.register(ops.WindowFunction, pd.Series)
@execute_node.register(ops.WindowFunction, [pd.Series])
def execute_window_op(
op,
data,
*data,
scope: Scope | None = None,
timecontext: TimeContext | None = None,
aggcontext=None,
Expand Down Expand Up @@ -485,33 +485,42 @@ def execute_series_group_by_last_value(op, data, aggcontext=None, **kwargs):
return aggcontext.agg(data, lambda x: _getter(x, -1))


@execute_node.register(ops.MinRank, (pd.Series, SeriesGroupBy))
def execute_series_min_rank(op, data, **kwargs):
# TODO(phillipc): Handle ORDER BY
@execute_node.register(ops.MinRank)
def execute_series_min_rank(op, aggcontext=None, **kwargs):
(key,) = aggcontext.order_by
df = aggcontext.parent
data = df[key]
return data.rank(method="min", ascending=True).astype("int64") - 1


@execute_node.register(ops.DenseRank, (pd.Series, SeriesGroupBy))
def execute_series_dense_rank(op, data, **kwargs):
# TODO(phillipc): Handle ORDER BY
@execute_node.register(ops.DenseRank)
def execute_series_dense_rank(op, aggcontext=None, **kwargs):
(key,) = aggcontext.order_by
df = aggcontext.parent
data = df[key]
return data.rank(method="dense", ascending=True).astype("int64") - 1


@execute_node.register(ops.PercentRank, SeriesGroupBy)
def execute_series_group_by_percent_rank(op, data, **kwargs):
return (
data.rank(method="min", ascending=True)
.sub(1)
.div(data.transform("count").sub(1))
)
@execute_node.register(ops.PercentRank)
def execute_series_group_by_percent_rank(op, aggcontext=None, **kwargs):
(key,) = aggcontext.order_by
df = aggcontext.parent
data = df[key]

result = data.rank(method="min", ascending=True) - 1

if isinstance(data, SeriesGroupBy):
nrows = data.transform("count")
else:
nrows = len(data)

@execute_node.register(ops.PercentRank, pd.Series)
def execute_series_percent_rank(op, data, **kwargs):
# TODO(phillipc): Handle ORDER BY
return data.rank(method="min", ascending=True).sub(1).div(len(data) - 1)
result /= nrows - 1
return result


@execute_node.register(ops.CumeDist, (pd.Series, SeriesGroupBy))
def execute_series_group_by_cume_dist(op, data, **kwargs):
@execute_node.register(ops.CumeDist)
def execute_series_group_by_cume_dist(op, aggcontext=None, **kwargs):
(key,) = aggcontext.order_by
df = aggcontext.parent
data = df[key]
return data.rank(method="min", ascending=True, pct=True)
9 changes: 4 additions & 5 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,8 @@ def calc_zscore(s):
lambda t: t.cumcount(),
id="row_number",
marks=[
pytest.mark.notimpl(
["dask", "pandas"], raises=com.OperationNotDefinedError
)
pytest.mark.notimpl(["dask"], raises=NotImplementedError),
pytest.mark.notimpl(["pandas"], raises=com.OperationNotDefinedError),
],
),
param(
Expand Down Expand Up @@ -891,9 +890,9 @@ def gb_fn(df):


@pytest.mark.notimpl(
["clickhouse", "dask", "datafusion", "polars"],
raises=com.OperationNotDefinedError,
["clickhouse", "datafusion", "polars"], raises=com.OperationNotDefinedError
)
@pytest.mark.notimpl(["dask"], raises=AttributeError)
@pytest.mark.notimpl(["pyspark"], raises=AnalysisException)
@pytest.mark.notyet(
["clickhouse"],
Expand Down

0 comments on commit b564fcb

Please sign in to comment.