Skip to content

Commit

Permalink
feat(rust, python): add maintain_order argument to sort/top_k/`bottom_k` (#9672)
Browse files Browse the repository at this point in the history
  • Loading branch information
CloseChoice authored Jul 12, 2023
1 parent 5a2129c commit a1d5a22
Show file tree
Hide file tree
Showing 35 changed files with 265 additions and 53 deletions.
2 changes: 1 addition & 1 deletion polars/polars-algo/src/algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,5 @@ pub fn hist(s: &Series, bins: Option<&Series>, bin_count: Option<usize>) -> Resu

cuts.left_join(&out, [category_str], [category_str])?
.fill_null(FillNullStrategy::Zero)?
.sort(["category"], false)
.sort(["category"], false, false)
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ impl CategoricalChunked {
counts.rename("counts");
let cols = vec![values.into_series(), counts.into_series()];
let df = DataFrame::new_no_checks(cols);
df.sort(["counts"], true)
df.sort(["counts"], true, false)
}
}
2 changes: 2 additions & 0 deletions polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ pub struct SortOptions {
pub descending: bool,
pub nulls_last: bool,
pub multithreaded: bool,
pub maintain_order: bool,
}

#[derive(Clone)]
Expand All @@ -495,6 +496,7 @@ impl Default for SortOptions {
descending: false,
nulls_last: false,
multithreaded: true,
maintain_order: false,
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions polars/polars-core/src/chunked_array/ops/sort/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl CategoricalChunked {
nulls_last: false,
descending,
multithreaded: true,
maintain_order: false,
})
}

Expand Down Expand Up @@ -202,12 +203,12 @@ mod test {
"vals" => [1, 1, 2, 2]
]?;

let out = df.sort(["cat", "vals"], vec![false, false])?;
let out = df.sort(["cat", "vals"], vec![false, false], false)?;
let out = out.column("cat")?;
let cat = out.categorical()?;
assert_order(cat, &["a", "a", "b", "c"]);

let out = df.sort(["vals", "cat"], vec![false, false])?;
let out = df.sort(["vals", "cat"], vec![false, false], false)?;
let out = out.column("cat")?;
let cat = out.categorical()?;
assert_order(cat, &["b", "c", "a", "a"]);
Expand Down
17 changes: 13 additions & 4 deletions polars/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ impl ChunkSort<Utf8Type> for Utf8Chunked {
descending,
nulls_last: false,
multithreaded: true,
maintain_order: false,
})
}

Expand Down Expand Up @@ -545,6 +546,7 @@ impl ChunkSort<BinaryType> for BinaryChunked {
descending,
nulls_last: false,
multithreaded: true,
maintain_order: false,
})
}

Expand Down Expand Up @@ -637,6 +639,7 @@ impl ChunkSort<BooleanType> for BooleanChunked {
descending,
nulls_last: false,
multithreaded: true,
maintain_order: false,
})
}

Expand Down Expand Up @@ -775,6 +778,7 @@ mod test {
descending: false,
nulls_last: false,
multithreaded: true,
maintain_order: false,
});
assert_eq!(
Vec::from(&out),
Expand All @@ -793,6 +797,7 @@ mod test {
descending: false,
nulls_last: true,
multithreaded: true,
maintain_order: false,
});
assert_eq!(
Vec::from(&out),
Expand All @@ -817,7 +822,7 @@ mod test {
let c = Utf8Chunked::new("c", &["a", "b", "c", "d", "e", "f", "g", "h"]);
let df = DataFrame::new(vec![a.into_series(), b.into_series(), c.into_series()])?;

let out = df.sort(["a", "b", "c"], false)?;
let out = df.sort(["a", "b", "c"], false, false)?;
assert_eq!(
Vec::from(out.column("b")?.i64()?),
&[
Expand All @@ -837,7 +842,7 @@ mod test {
let b = Int32Chunked::new("b", &[5, 4, 2, 3, 4, 5]).into_series();
let df = DataFrame::new(vec![a, b])?;

let out = df.sort(["a", "b"], false)?;
let out = df.sort(["a", "b"], false, false)?;
let expected = df!(
"a" => ["a", "a", "b", "b", "c", "c"],
"b" => [3, 5, 4, 4, 2, 5]
Expand All @@ -849,14 +854,14 @@ mod test {
"values" => ["a", "a", "b"]
)?;

let out = df.sort(["groups", "values"], vec![true, false])?;
let out = df.sort(["groups", "values"], vec![true, false], false)?;
let expected = df!(
"groups" => [3, 2, 1],
"values" => ["b", "a", "a"]
)?;
assert!(out.frame_equal(&expected));

let out = df.sort(["values", "groups"], vec![false, true])?;
let out = df.sort(["values", "groups"], vec![false, true], false)?;
let expected = df!(
"groups" => [2, 1, 3],
"values" => ["a", "a", "b"]
Expand All @@ -873,6 +878,7 @@ mod test {
descending: false,
nulls_last: false,
multithreaded: true,
maintain_order: false,
});
let expected = &[None, None, Some("a"), Some("b"), Some("c")];
assert_eq!(Vec::from(&out), expected);
Expand All @@ -881,6 +887,7 @@ mod test {
descending: true,
nulls_last: false,
multithreaded: true,
maintain_order: false,
});

let expected = &[None, None, Some("c"), Some("b"), Some("a")];
Expand All @@ -890,6 +897,7 @@ mod test {
descending: false,
nulls_last: true,
multithreaded: true,
maintain_order: false,
});
let expected = &[Some("a"), Some("b"), Some("c"), None, None];
assert_eq!(Vec::from(&out), expected);
Expand All @@ -898,6 +906,7 @@ mod test {
descending: true,
nulls_last: true,
multithreaded: true,
maintain_order: false,
});
let expected = &[Some("c"), Some("b"), Some("a"), None, None];
assert_eq!(Vec::from(&out), expected);
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/frame/groupby/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ mod test {
// Use of deprecated `sum()` for testing purposes
#[allow(deprecated)]
let res = df.groupby(["flt"]).unwrap().sum().unwrap();
let res = res.sort(["flt"], false).unwrap();
let res = res.sort(["flt"], false, false).unwrap();
assert_eq!(
Vec::from(res.column("val_sum").unwrap().i32().unwrap()),
&[Some(2), Some(2), Some(1)]
Expand Down
3 changes: 3 additions & 0 deletions polars/polars-core/src/frame/hash_join/sort_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ pub fn _sort_or_hash_inner(
descending: false,
nulls_last: false,
multithreaded: true,
maintain_order: false,
});
let s_right = unsafe { s_right.take_unchecked(&sort_idx).unwrap() };
let ids = par_sorted_merge_inner_no_nulls(s_left, &s_right);
Expand All @@ -250,6 +251,7 @@ pub fn _sort_or_hash_inner(
descending: false,
nulls_last: false,
multithreaded: true,
maintain_order: false,
});
let s_left = unsafe { s_left.take_unchecked(&sort_idx).unwrap() };
let ids = par_sorted_merge_inner_no_nulls(&s_left, s_right);
Expand Down Expand Up @@ -318,6 +320,7 @@ pub(super) fn sort_or_hash_left(
descending: false,
nulls_last: false,
multithreaded: true,
maintain_order: false,
});
let s_right = unsafe { s_right.take_unchecked(&sort_idx).unwrap() };

Expand Down
17 changes: 11 additions & 6 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1841,11 +1841,12 @@ impl DataFrame {
&mut self,
by_column: impl IntoVec<SmartString>,
descending: impl IntoVec<bool>,
maintain_order: bool,
) -> PolarsResult<&mut Self> {
let by_column = self.select_series(by_column)?;
let descending = descending.into_vec();
self.columns = self
.sort_impl(by_column, descending, false, None, true)?
.sort_impl(by_column, descending, false, maintain_order, None, true)?
.columns;
Ok(self)
}
Expand All @@ -1856,6 +1857,7 @@ impl DataFrame {
by_column: Vec<Series>,
descending: Vec<bool>,
nulls_last: bool,
maintain_order: bool,
slice: Option<(i64, usize)>,
parallel: bool,
) -> PolarsResult<Self> {
Expand Down Expand Up @@ -1890,7 +1892,7 @@ impl DataFrame {
}

if let Some((0, k)) = slice {
return self.top_k_impl(k, descending, by_column, nulls_last);
return self.top_k_impl(k, descending, by_column, nulls_last, maintain_order);
}

#[cfg(feature = "dtype-struct")]
Expand All @@ -1912,6 +1914,7 @@ impl DataFrame {
descending: descending[0],
nulls_last,
multithreaded: parallel,
maintain_order,
};
// fast path for a frame with a single series
// no need to compute the sort indices and then take by these indices
Expand Down Expand Up @@ -1959,20 +1962,21 @@ impl DataFrame {
/// ```
/// # use polars_core::prelude::*;
/// fn sort_example(df: &DataFrame, descending: bool) -> PolarsResult<DataFrame> {
/// df.sort(["a"], descending)
/// df.sort(["a"], descending, false)
/// }
///
/// fn sort_by_multiple_columns_example(df: &DataFrame) -> PolarsResult<DataFrame> {
/// df.sort(&["a", "b"], vec![false, true])
/// df.sort(&["a", "b"], vec![false, true], false)
/// }
/// ```
pub fn sort(
&self,
by_column: impl IntoVec<SmartString>,
descending: impl IntoVec<bool>,
maintain_order: bool,
) -> PolarsResult<Self> {
let mut df = self.clone();
df.sort_in_place(by_column, descending)?;
df.sort_in_place(by_column, descending, maintain_order)?;
Ok(df)
}

Expand All @@ -1986,6 +1990,7 @@ impl DataFrame {
by_column,
descending,
options.nulls_last,
options.maintain_order,
None,
options.multithreaded,
)?
Expand Down Expand Up @@ -3603,7 +3608,7 @@ mod test {
let df = df
.unique_stable(None, UniqueKeepStrategy::First, None)
.unwrap()
.sort(["flt"], false)
.sort(["flt"], false, false)
.unwrap();
let valid = df! {
"flt" => [1., 2., 3.],
Expand Down
13 changes: 11 additions & 2 deletions polars/polars-core/src/frame/top_k.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl DataFrame {
) -> PolarsResult<DataFrame> {
let by_column = self.select_series(by_column)?;
let descending = descending.into_vec();
self.top_k_impl(k, descending, by_column, false)
self.top_k_impl(k, descending, by_column, false, false)
}

pub(crate) fn top_k_impl(
Expand All @@ -55,6 +55,7 @@ impl DataFrame {
mut descending: Vec<bool>,
by_column: Vec<Series>,
nulls_last: bool,
maintain_order: bool,
) -> PolarsResult<DataFrame> {
_broadcast_descending(by_column.len(), &mut descending);
let encoded = _get_rows_encoded(&by_column, &descending, nulls_last)?;
Expand All @@ -66,8 +67,16 @@ impl DataFrame {
.collect::<Vec<_>>();

let sorted = if k >= self.height() {
rows.sort_unstable();
if maintain_order {
rows.sort();
} else {
rows.sort_unstable();
}
&rows
} else if maintain_order {
// todo: maybe there is some more efficient method, comparable to select_nth_unstable
rows.sort();
&rows[..k]
} else {
let (lower, _el, _upper) = rows.select_nth_unstable(k);
lower.sort_unstable();
Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/src/series/implementations/struct_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ impl SeriesTrait for SeriesWrap<StructChunked> {
df.columns.clone(),
desc,
options.nulls_last,
options.maintain_order,
None,
options.multithreaded,
)
Expand Down
10 changes: 9 additions & 1 deletion polars/polars-lazy/polars-pipe/src/executors/sinks/sort/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ impl Sink for SortSink {
descending: self.sort_args.descending[0],
nulls_last: self.sort_args.nulls_last,
multithreaded: true,
maintain_order: self.sort_args.maintain_order,
});

block_thread_until_io_thread_done(io_thread);
Expand Down Expand Up @@ -216,5 +217,12 @@ pub(super) fn sort_accumulated(
slice: Option<(i64, usize)>,
) -> PolarsResult<DataFrame> {
let sort_column = df.get_columns()[sort_idx].clone();
df.sort_impl(vec![sort_column], vec![descending], false, slice, true)
df.sort_impl(
vec![sort_column],
vec![descending],
false,
false,
slice,
true,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ impl SortSinkMultiple {
descending: vec![false],
nulls_last: false,
slice: sort_args.slice,
maintain_order: false,
},
Arc::new(schema),
));
Expand Down
9 changes: 8 additions & 1 deletion polars/polars-lazy/polars-plan/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,13 @@ impl LogicalPlanBuilder {
.into()
}

pub fn sort(self, by_column: Vec<Expr>, descending: Vec<bool>, null_last: bool) -> Self {
pub fn sort(
self,
by_column: Vec<Expr>,
descending: Vec<bool>,
null_last: bool,
maintain_order: bool,
) -> Self {
let schema = try_delayed!(self.0.schema(), &self.0, into);
let by_column = try_delayed!(rewrite_projections(by_column, &schema, &[]), &self.0, into);
LogicalPlan::Sort {
Expand All @@ -643,6 +649,7 @@ impl LogicalPlanBuilder {
descending,
nulls_last: null_last,
slice: None,
maintain_order,
},
}
.into()
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/polars-plan/src/logical_plan/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ pub struct SortArguments {
pub descending: Vec<bool>,
pub nulls_last: bool,
pub slice: Option<(i64, usize)>,
pub maintain_order: bool,
}

#[derive(Clone, PartialEq, Eq, Debug, Default)]
Expand Down
Loading

0 comments on commit a1d5a22

Please sign in to comment.