diff --git a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs
index e7503e4d6995..a59c13ef3d81 100644
--- a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs
+++ b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs
@@ -43,6 +43,25 @@ fn prepare_descending(descending: &[bool], by_len: usize) -> Vec {
     }
 }
 
+/// Checks that the groups of the `sort_by` key line up with the groups of the sorted input.
+///
+/// The two proxies must hold the same number of groups, and every pair of corresponding
+/// groups must have the same length.
+///
+/// # Errors
+///
+/// Returns a `ComputeError` if the group counts or any pair of group lengths differ.
+fn check_groups(a: &GroupsProxy, b: &GroupsProxy) -> PolarsResult<()> {
+    // `zip` stops at the shorter side, so the number of groups is compared explicitly.
+    let same_layout =
+        a.len() == b.len() && a.iter().zip(b.iter()).all(|(ga, gb)| ga.len() == gb.len());
+    polars_ensure!(
+        same_layout,
+        ComputeError: "expressions in 'sort_by' produced a different number of groups"
+    );
+    Ok(())
+}
+
 impl PhysicalExpr for SortByExpr {
     fn as_expression(&self) -> Option<&Expr> {
         Some(&self.expr)
@@ -161,43 +180,49 @@ impl PhysicalExpr for SortByExpr {
             );
             let groups = ac_sort_by.groups();
 
-            let groups = POOL.install(|| {
-                groups
-                    .par_iter()
-                    .map(|indicator| {
-                        let new_idx = match indicator {
-                            GroupsIndicator::Idx((_, idx)) => {
-                                // SAFETY: group tuples are always in bounds.
-                                let group = unsafe { sort_by_s.take_slice_unchecked(idx) };
+            // Validate the group layout alongside the per-group arg-sort; a mismatch is
+            // reported through `check?` once both closures have returned.
+            let (check, groups) = POOL.join(
+                || check_groups(groups, ac_in.groups()),
+                || {
+                    groups
+                        .par_iter()
+                        .map(|indicator| {
+                            let new_idx = match indicator {
+                                GroupsIndicator::Idx((_, idx)) => {
+                                    // SAFETY: group tuples are always in bounds.
+                                    let group = unsafe { sort_by_s.take_slice_unchecked(idx) };
 
-                                let sorted_idx = group.arg_sort(SortOptions {
-                                    descending: descending[0],
-                                    // We are already in par iter.
-                                    multithreaded: false,
-                                    ..Default::default()
-                                });
-                                map_sorted_indices_to_group_idx(&sorted_idx, idx)
-                            },
-                            GroupsIndicator::Slice([first, len]) => {
-                                let group = sort_by_s.slice(first as i64, len as usize);
-                                let sorted_idx = group.arg_sort(SortOptions {
-                                    descending: descending[0],
-                                    // We are already in par iter.
-                                    multithreaded: false,
-                                    ..Default::default()
-                                });
-                                map_sorted_indices_to_group_slice(&sorted_idx, first)
-                            },
-                        };
-                        let first = new_idx.first().unwrap_or_else(|| {
-                            invalid.store(true, Ordering::Relaxed);
-                            &0
-                        });
+                                    let sorted_idx = group.arg_sort(SortOptions {
+                                        descending: descending[0],
+                                        // We are already in par iter.
+                                        multithreaded: false,
+                                        ..Default::default()
+                                    });
+                                    map_sorted_indices_to_group_idx(&sorted_idx, idx)
+                                },
+                                GroupsIndicator::Slice([first, len]) => {
+                                    let group = sort_by_s.slice(first as i64, len as usize);
+                                    let sorted_idx = group.arg_sort(SortOptions {
+                                        descending: descending[0],
+                                        // We are already in par iter.
+                                        multithreaded: false,
+                                        ..Default::default()
+                                    });
+                                    map_sorted_indices_to_group_slice(&sorted_idx, first)
+                                },
+                            };
+                            let first = new_idx.first().unwrap_or_else(|| {
+                                invalid.store(true, Ordering::Relaxed);
+                                &0
+                            });
 
-                        (*first, new_idx)
-                    })
-                    .collect()
-            });
+                            (*first, new_idx)
+                        })
+                        .collect()
+                },
+            );
+            check?;
 
             (GroupsProxy::Idx(groups), ordered_by_group_operation)
         } else {
diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py
index 2b54218b3e29..846d4074495f 100644
--- a/py-polars/tests/unit/test_errors.py
+++ b/py-polars/tests/unit/test_errors.py
@@ -691,3 +691,24 @@ def test_empty_inputs_error() -> None:
         pl.ComputeError, match="expression: 'fold' didn't get any inputs"
     ):
         df.select(pl.sum_horizontal(pl.exclude("col1")))
+
+
+def test_sort_by_error() -> None:
+    df = pl.DataFrame(
+        {
+            "id": [1, 1, 1, 2, 2, 3, 3, 3],
+            "number": [1, 3, 2, 1, 2, 2, 1, 3],
+            "type": ["A", "B", "A", "B", "B", "A", "B", "C"],
+            "cost": [10, 25, 20, 25, 30, 30, 50, 100],
+        }
+    )
+
+    # The filter shrinks each group of "cost", so its group lengths no longer match
+    # those of the "number" sort key and `sort_by` has to raise.
+    with pytest.raises(
+        pl.ComputeError,
+        match="expressions in 'sort_by' produced a different number of groups",
+    ):
+        df.group_by("id", maintain_order=True).agg(
+            pl.col("cost").filter(pl.col("type") == "A").sort_by("number")
+        )