diff --git a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs index a9f3c5662297..487c6da9d803 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs @@ -140,12 +140,6 @@ impl PhysicalExpr for TernaryExpr { state: &ExecutionState, ) -> PolarsResult> { let aggregation_predicate = self.predicate.is_valid_aggregation(); - if !aggregation_predicate { - // Unwrap will not fail as it is not an aggregation expression. - eprintln!( - "The predicate '{}' in 'when->then->otherwise' is not a valid aggregation and might produce a different number of rows than the group_by operation would. This behavior is experimental and may be subject to change", self.predicate.as_expression().unwrap() - ) - } let op_mask = || self.predicate.evaluate_on_groups(df, groups, state); let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state); @@ -197,7 +191,6 @@ impl PhysicalExpr for TernaryExpr { // None (AggregatedList(_), Literal(_)) | (Literal(_), AggregatedList(_)) => { if !aggregation_predicate { - // Experimental elementwise behavior tested in `test_binary_agg_context_1`. return finish_as_iters(ac_truthy, ac_falsy, ac_mask); } let mask = mask_s.bool()?; @@ -299,7 +292,6 @@ impl PhysicalExpr for TernaryExpr { } if !aggregation_predicate { - // Experimental elementwise behavior tested in `test_binary_agg_context_1`. return finish_as_iters(ac_truthy, ac_falsy, ac_mask); } let mut mask = mask_s.bool()?.clone(); diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 786f3d72d0cb..9098a3893ae4 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -870,3 +870,20 @@ def test_group_by_double_on_empty_12194() -> None: assert df.group_by("group").agg(squared_deviation_sum).schema == OrderedDict( [("group", pl.Int64), ("x", pl.Float64)] ) + + +def test_group_by_when_then_no_aggregation_predicate() -> None: + df = pl.DataFrame( + { + "key": ["aa", "aa", "bb", "bb", "aa", "aa"], + "val": [-3, -2, 1, 4, -3, 5], + } + ) + assert df.group_by("key").agg( + pos=pl.when(pl.col("val") >= 0).then(pl.col("val")).sum(), + neg=pl.when(pl.col("val") < 0).then(pl.col("val")).sum(), + ).sort("key").to_dict(as_series=False) == { + "key": ["aa", "bb"], + "pos": [5, 5], + "neg": [-8, 0], + }