Skip to content

Commit

Permalink
fix: raise on invalid sort_by group lengths (#11423)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Sep 29, 2023
1 parent fc89548 commit 6f0705b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 35 deletions.
81 changes: 46 additions & 35 deletions crates/polars-lazy/src/physical_plan/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ fn prepare_descending(descending: &[bool], by_len: usize) -> Vec<bool> {
}
}

/// Ensures the groups of `a` and `b`, paired by position, have matching lengths.
///
/// Returns a `ComputeError` as soon as a pair of groups with differing lengths
/// is encountered, and `Ok(())` otherwise.
///
/// Groups are paired with `zip`, which stops at the shorter of the two
/// iterators; surplus groups on either side are therefore not compared here.
fn check_groups(a: &GroupsProxy, b: &GroupsProxy) -> PolarsResult<()> {
    let group_lengths_match = a
        .iter()
        .zip(b.iter())
        .all(|(lhs, rhs)| lhs.len() == rhs.len());
    polars_ensure!(
        group_lengths_match,
        ComputeError: "expressions in 'sort_by' produced a different number of groups"
    );
    Ok(())
}

impl PhysicalExpr for SortByExpr {
fn as_expression(&self) -> Option<&Expr> {
Some(&self.expr)
Expand Down Expand Up @@ -161,43 +168,47 @@ impl PhysicalExpr for SortByExpr {
);
let groups = ac_sort_by.groups();

let groups = POOL.install(|| {
groups
.par_iter()
.map(|indicator| {
let new_idx = match indicator {
GroupsIndicator::Idx((_, idx)) => {
// SAFETY: group tuples are always in bounds.
let group = unsafe { sort_by_s.take_slice_unchecked(idx) };
let (check, groups) = POOL.join(
|| check_groups(groups, ac_in.groups()),
|| {
groups
.par_iter()
.map(|indicator| {
let new_idx = match indicator {
GroupsIndicator::Idx((_, idx)) => {
// SAFETY: group tuples are always in bounds.
let group = unsafe { sort_by_s.take_slice_unchecked(idx) };

let sorted_idx = group.arg_sort(SortOptions {
descending: descending[0],
// We are already in par iter.
multithreaded: false,
..Default::default()
});
map_sorted_indices_to_group_idx(&sorted_idx, idx)
},
GroupsIndicator::Slice([first, len]) => {
let group = sort_by_s.slice(first as i64, len as usize);
let sorted_idx = group.arg_sort(SortOptions {
descending: descending[0],
// We are already in par iter.
multithreaded: false,
..Default::default()
});
map_sorted_indices_to_group_slice(&sorted_idx, first)
},
};
let first = new_idx.first().unwrap_or_else(|| {
invalid.store(true, Ordering::Relaxed);
&0
});
let sorted_idx = group.arg_sort(SortOptions {
descending: descending[0],
// We are already in par iter.
multithreaded: false,
..Default::default()
});
map_sorted_indices_to_group_idx(&sorted_idx, idx)
},
GroupsIndicator::Slice([first, len]) => {
let group = sort_by_s.slice(first as i64, len as usize);
let sorted_idx = group.arg_sort(SortOptions {
descending: descending[0],
// We are already in par iter.
multithreaded: false,
..Default::default()
});
map_sorted_indices_to_group_slice(&sorted_idx, first)
},
};
let first = new_idx.first().unwrap_or_else(|| {
invalid.store(true, Ordering::Relaxed);
&0
});

(*first, new_idx)
})
.collect()
});
(*first, new_idx)
})
.collect()
},
);
check?;

(GroupsProxy::Idx(groups), ordered_by_group_operation)
} else {
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,3 +691,22 @@ def test_empty_inputs_error() -> None:
pl.ComputeError, match="expression: 'fold' didn't get any inputs"
):
df.select(pl.sum_horizontal(pl.exclude("col1")))


def test_sort_by_error() -> None:
    """Sorting a per-group filtered column by an unfiltered one must raise.

    Within each ``id`` group, ``cost`` is filtered on ``type == "A"`` and then
    sorted by the unfiltered ``number`` column; the aggregation is expected to
    fail with a ``ComputeError`` instead of returning a result.
    """
    frame = pl.DataFrame(
        {
            "id": [1, 1, 1, 2, 2, 3, 3, 3],
            "number": [1, 3, 2, 1, 2, 2, 1, 3],
            "type": ["A", "B", "A", "B", "B", "A", "B", "C"],
            "cost": [10, 25, 20, 25, 30, 30, 50, 100],
        }
    )
    expected_message = "expressions in 'sort_by' produced a different number of groups"

    with pytest.raises(pl.ComputeError, match=expected_message):
        grouped = frame.group_by("id", maintain_order=True)
        grouped.agg(pl.col("cost").filter(pl.col("type") == "A").sort_by("number"))

0 comments on commit 6f0705b

Please sign in to comment.