Skip to content

Commit

Permalink
feat: allow non-aggregation predicate in ternary groupby (pola-rs#12286)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Nov 7, 2023
1 parent 450fbab commit 817dcae
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
8 changes: 0 additions & 8 deletions crates/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,6 @@ impl PhysicalExpr for TernaryExpr {
state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
let aggregation_predicate = self.predicate.is_valid_aggregation();
if !aggregation_predicate {
// Unwrap will not fail as it is not an aggregation expression.
eprintln!(
"The predicate '{}' in 'when->then->otherwise' is not a valid aggregation and might produce a different number of rows than the group_by operation would. This behavior is experimental and may be subject to change", self.predicate.as_expression().unwrap()
)
}

let op_mask = || self.predicate.evaluate_on_groups(df, groups, state);
let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state);
Expand Down Expand Up @@ -197,7 +191,6 @@ impl PhysicalExpr for TernaryExpr {
// None
(AggregatedList(_), Literal(_)) | (Literal(_), AggregatedList(_)) => {
if !aggregation_predicate {
// Experimental elementwise behavior tested in `test_binary_agg_context_1`.
return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
}
let mask = mask_s.bool()?;
Expand Down Expand Up @@ -299,7 +292,6 @@ impl PhysicalExpr for TernaryExpr {
}

if !aggregation_predicate {
// Experimental elementwise behavior tested in `test_binary_agg_context_1`.
return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
}
let mut mask = mask_s.bool()?.clone();
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,3 +870,20 @@ def test_group_by_double_on_empty_12194() -> None:
assert df.group_by("group").agg(squared_deviation_sum).schema == OrderedDict(
[("group", pl.Int64), ("x", pl.Float64)]
)


def test_group_by_when_then_no_aggregation_predicate() -> None:
df = pl.DataFrame(
{
"key": ["aa", "aa", "bb", "bb", "aa", "aa"],
"val": [-3, -2, 1, 4, -3, 5],
}
)
assert df.group_by("key").agg(
pos=pl.when(pl.col("val") >= 0).then(pl.col("val")).sum(),
neg=pl.when(pl.col("val") < 0).then(pl.col("val")).sum(),
).sort("key").to_dict(as_series=False) == {
"key": ["aa", "bb"],
"pos": [5, 5],
"neg": [-8, 0],
}

0 comments on commit 817dcae

Please sign in to comment.