Skip to content

Commit

Permalink
fix: Properly broadcast in sort_by (pola-rs#20434)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Jan 8, 2025
1 parent 7ddbf3a commit 044bc51
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 19 deletions.
55 changes: 40 additions & 15 deletions crates/polars-expr/src/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,33 +219,58 @@ impl PhysicalExpr for SortByExpr {
let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len());

let sorted_idx_f = || {
let s_sort_by = self
let mut needs_broadcast = false;
let mut broadcast_length = 1;

let mut s_sort_by = self
.by
.iter()
.map(|e| {
e.evaluate(df, state).map(|s| match s.dtype() {
.enumerate()
.map(|(i, e)| {
let column = e.evaluate(df, state).map(|c| match c.dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_, _) | DataType::Enum(_, _) => s,
_ => s.to_physical_repr(),
})
DataType::Categorical(_, _) | DataType::Enum(_, _) => c,
_ => c.to_physical_repr(),
})?;

if column.len() == 1 && broadcast_length != 1 {
polars_ensure!(
e.is_scalar(),
ShapeMismatch: "non-scalar expression produces broadcasting column",
);

return Ok(column.new_from_index(0, broadcast_length));
}

if broadcast_length != column.len() {
polars_ensure!(
broadcast_length == 1, ShapeMismatch:
"`sort_by` produced different length ({}) than earlier Series' length in `by` ({})",
broadcast_length, column.len()
);

needs_broadcast |= i > 0;
broadcast_length = column.len();
}

Ok(column)
})
.collect::<PolarsResult<Vec<_>>>()?;

if needs_broadcast {
for c in s_sort_by.iter_mut() {
if c.len() != broadcast_length {
*c = c.new_from_index(0, broadcast_length);
}
}
}

let options = self
.sort_options
.clone()
.with_order_descending_multi(descending)
.with_nulls_last_multi(nulls_last);

for i in 1..s_sort_by.len() {
polars_ensure!(
s_sort_by[0].len() == s_sort_by[i].len(),
expr = self.expr, ShapeMismatch:
"`sort_by` produced different length ({}) than earlier Series' length in `by` ({})",
s_sort_by[0].len(), s_sort_by[i].len()
);
}

s_sort_by[0]
.as_materialized_series()
.arg_sort_multiple(&s_sort_by[1..], &options)
Expand Down
9 changes: 5 additions & 4 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,11 +1072,12 @@ def test_sort_string_nulls() -> None:
]


@pytest.mark.may_fail_auto_streaming
def test_sort_by_unequal_lengths_7207() -> None:
df = pl.DataFrame({"a": [0, 1, 1, 0], "b": [3, 2, 3, 2]})
with pytest.raises(pl.exceptions.ShapeError):
df.select(pl.col.a.sort_by(["a", 1]))
df = pl.DataFrame({"a": [0, 1, 1, 0]})
result = df.select(pl.arg_sort_by(["a", 1]))

expected = pl.DataFrame({"a": [0, 3, 1, 2]})
assert_frame_equal(result, expected, check_dtypes=False)


def test_sort_literals() -> None:
Expand Down

0 comments on commit 044bc51

Please sign in to comment.