feat(rust, python): Improve cut and allow use in expressions (#9580)
magarick authored Jul 5, 2023
1 parent caabbb9 commit 9ff6908
Showing 18 changed files with 420 additions and 28 deletions.
1 change: 1 addition & 0 deletions polars/Cargo.toml
@@ -160,6 +160,7 @@ streaming = ["polars-lazy/streaming"]
fused = ["polars-ops/fused", "polars-lazy/fused"]
list_sets = ["polars-lazy/list_sets"]
list_any_all = ["polars-lazy/list_any_all"]
cutqcut = ["polars-lazy/cutqcut"]

test = [
"lazy",
1 change: 1 addition & 0 deletions polars/polars-lazy/Cargo.toml
@@ -133,6 +133,7 @@ serde = [
fused = ["polars-plan/fused", "polars-ops/fused"]
list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"]
list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"]
cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"]

binary_encoding = ["polars-plan/binary_encoding"]

1 change: 1 addition & 0 deletions polars/polars-lazy/polars-plan/Cargo.toml
@@ -121,6 +121,7 @@ coalesce = []
fused = []
list_sets = ["polars-ops/list_sets"]
list_any_all = ["polars-ops/list_any_all"]
cutqcut = ["polars-ops/cutqcut"]

bigidx = ["polars-arrow/bigidx", "polars-core/bigidx", "polars-utils/bigidx"]

38 changes: 38 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
@@ -63,6 +63,8 @@ pub(crate) use correlation::CorrelationMethod;
pub(crate) use fused::FusedOperator;
pub(super) use list::ListFunction;
use polars_core::prelude::*;
#[cfg(feature = "cutqcut")]
use polars_ops::prelude::{cut, qcut};
#[cfg(feature = "random")]
pub(crate) use random::RandomMethod;
use schema::FieldsMapper;
@@ -198,6 +200,19 @@ pub enum FunctionExpr {
method: correlation::CorrelationMethod,
ddof: u8,
},
#[cfg(feature = "cutqcut")]
Cut {
breaks: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
},
#[cfg(feature = "cutqcut")]
QCut {
probs: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
allow_duplicates: bool,
},
ToPhysical,
#[cfg(feature = "random")]
Random {
@@ -301,6 +316,10 @@ impl Display for FunctionExpr {
ArrayExpr(af) => return Display::fmt(af, f),
ConcatExpr(_) => "concat_expr",
Correlation { method, .. } => return Display::fmt(method, f),
#[cfg(feature = "cutqcut")]
Cut { .. } => "cut",
#[cfg(feature = "cutqcut")]
QCut { .. } => "qcut",
ToPhysical => "to_physical",
#[cfg(feature = "random")]
Random { method, .. } => method.into(),
@@ -530,6 +549,25 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Fused(op) => map_as_slice!(fused::fused, op),
ConcatExpr(rechunk) => map_as_slice!(concat::concat_expr, rechunk),
Correlation { method, ddof } => map_as_slice!(correlation::corr, ddof, method),
#[cfg(feature = "cutqcut")]
Cut {
breaks,
labels,
left_closed,
} => map!(cut, breaks.clone(), labels.clone(), left_closed),
#[cfg(feature = "cutqcut")]
QCut {
probs,
labels,
left_closed,
allow_duplicates,
} => map!(
qcut,
probs.clone(),
labels.clone(),
left_closed,
allow_duplicates
),
ToPhysical => map!(dispatch::to_physical),
#[cfg(feature = "random")]
Random {
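The two new variants are routed through the `map!` macro, which clones the captured options and applies the series-level `cut`/`qcut` from `polars_ops` to the input column. A rough, hand-written sketch of the shape of that dispatch (an assumption about what the macro produces, not its actual expansion):

use polars_core::prelude::*;
use polars_ops::prelude::cut;

// Hypothetical helper: capture the binning options once and return a closure
// that can later be applied to the evaluated input Series.
fn cut_dispatch(
    breaks: Vec<f64>,
    labels: Option<Vec<String>>,
    left_closed: bool,
) -> impl Fn(&Series) -> PolarsResult<Series> {
    move |s: &Series| cut(s, breaks.clone(), labels.clone(), left_closed)
}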
@@ -234,6 +234,10 @@ impl FunctionExpr {
Fused(_) => mapper.map_to_supertype(),
ConcatExpr(_) => mapper.map_to_supertype(),
Correlation { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "cutqcut")]
Cut { .. } => mapper.with_dtype(DataType::Categorical(None)),
#[cfg(feature = "cutqcut")]
QCut { .. } => mapper.with_dtype(DataType::Categorical(None)),
ToPhysical => mapper.to_physical_type(),
#[cfg(feature = "random")]
Random { .. } => mapper.with_same_dtype(),
33 changes: 33 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
@@ -1464,6 +1464,39 @@ impl Expr {
.with_fmt("rank")
}

#[cfg(feature = "cutqcut")]
pub fn cut(self, breaks: Vec<f64>, labels: Option<Vec<String>>, left_closed: bool) -> Expr {
self.apply_private(FunctionExpr::Cut {
breaks,
labels,
left_closed,
})
.with_function_options(|mut opt| {
opt.allow_group_aware = false;
opt
})
}

#[cfg(feature = "cutqcut")]
pub fn qcut(
self,
probs: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
allow_duplicates: bool,
) -> Expr {
self.apply_private(FunctionExpr::QCut {
probs,
labels,
left_closed,
allow_duplicates,
})
.with_function_options(|mut opt| {
opt.allow_group_aware = false;
opt
})
}

#[cfg(feature = "diff")]
pub fn diff(self, n: i64, null_behavior: NullBehavior) -> Expr {
self.apply_private(FunctionExpr::Diff(n, null_behavior))
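With these methods on `Expr`, binning can be written inline in a lazy query and resolves to a Categorical column (see the schema mapping above). A minimal usage sketch, assuming a build with the `lazy` and `cutqcut` features and made-up column names:

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let df = df!("x" => &[-2.0f64, 0.5, 1.5, 3.0])?;
    let out = df
        .lazy()
        .with_column(col("x").cut(vec![0.0, 2.0], None, false).alias("x_bin"))
        .collect()?;
    // "x_bin" should be Categorical, with default labels such as
    // "(-inf, 0]", "(0, 2]" and "(2, inf]" (see cut.rs further down).
    println!("{out}");
    Ok(())
}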
10 changes: 5 additions & 5 deletions polars/polars-lazy/polars-plan/src/logical_plan/options.rs
@@ -190,6 +190,9 @@ pub struct FunctionOptions {
/// Collect groups to a list and apply the function over the groups.
/// This can be important in aggregation context.
pub collect_groups: ApplyOptions,
// used for formatting, (only for anonymous functions)
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
pub fmt_str: &'static str,
/// There can be two ways of expanding wildcards:
///
/// Say the schema is 'a', 'b' and there is a function f
@@ -205,7 +208,6 @@ pub struct FunctionOptions {
///
/// this also accounts for regex expansion
pub input_wildcard_expansion: bool,

/// automatically explode on unit length it ran as final aggregation.
///
/// this is the case for aggregations like sum, min, covariance etc.
@@ -217,10 +219,6 @@ pub struct FunctionOptions {
/// head_1(x) -> {1}
/// sum(x) -> {4}
pub auto_explode: bool,
// used for formatting, (only for anonymous functions)
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
pub fmt_str: &'static str,

// if the expression and its inputs should be cast to supertypes
pub cast_to_supertypes: bool,
// apply physical expression may rename the output of this function
@@ -233,6 +231,7 @@ pub struct FunctionOptions {
// Validate the output of a `map`.
// this should always be true or we could OOB
pub check_lengths: UnsafeBool,
pub allow_group_aware: bool,
}

impl FunctionOptions {
@@ -265,6 +264,7 @@ impl Default for FunctionOptions {
pass_name_to_apply: false,
changes_length: false,
check_lengths: UnsafeBool(true),
allow_group_aware: true,
}
}
}
7 changes: 7 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
@@ -26,6 +26,7 @@ pub struct ApplyExpr {
pub input_schema: Option<SchemaRef>,
pub allow_threading: bool,
pub check_lengths: bool,
pub allow_group_aware: bool,
}

impl ApplyExpr {
@@ -46,6 +47,7 @@ impl ApplyExpr {
input_schema: None,
allow_threading: true,
check_lengths: true,
allow_group_aware: true,
}
}

@@ -280,6 +282,11 @@ impl PhysicalExpr for ApplyExpr {
groups: &'a GroupsProxy,
state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
polars_ensure!(
self.allow_group_aware,
expr = self.expr,
ComputeError: "this expression cannot run in the groupby context",
);
if self.inputs.len() == 1 {
let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

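Because `cut` and `qcut` set `allow_group_aware = false`, this check should make them error out in an aggregation context instead of producing per-group bins. A hedged sketch of how that would surface to a user (made-up data and column names):

use polars::prelude::*;

fn main() {
    let df = df!("g" => &["a", "a", "b"], "x" => &[1.0f64, 2.0, 3.0]).unwrap();
    let res = df
        .lazy()
        .groupby([col("g")])
        .agg([col("x").cut(vec![1.5], None, false)])
        .collect();
    // Expected to fail with the ComputeError added above:
    // "this expression cannot run in the groupby context"
    assert!(res.is_err());
}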
2 changes: 2 additions & 0 deletions polars/polars-lazy/src/physical_plan/planner/expr.rs
@@ -464,6 +464,7 @@ pub(crate) fn create_physical_expr(
input_schema: schema.cloned(),
allow_threading: !state.has_cache,
check_lengths: options.check_lengths(),
allow_group_aware: options.allow_group_aware,
}))
}
Function {
@@ -499,6 +500,7 @@
input_schema: schema.cloned(),
allow_threading: !state.has_cache,
check_lengths: options.check_lengths(),
allow_group_aware: options.allow_group_aware,
}))
}
Slice {
1 change: 1 addition & 0 deletions polars/polars-ops/Cargo.toml
@@ -50,6 +50,7 @@ is_first = []
is_unique = []
approx_unique = []
fused = []
cutqcut = ["dtype-categorical"]

# extra utilities for BinaryChunked
binary_encoding = ["base64", "hex"]
100 changes: 100 additions & 0 deletions polars/polars-ops/src/series/ops/cut.rs
@@ -0,0 +1,100 @@
use std::iter::once;

use polars_core::prelude::*;

pub fn cut(
s: &Series,
breaks: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
) -> PolarsResult<Series> {
polars_ensure!(!breaks.is_empty(), ShapeMismatch: "Breaks are empty");
polars_ensure!(!breaks.iter().any(|x| x.is_nan()), ComputeError: "Breaks cannot be NaN");
// Breaks must be sorted to cut inputs properly.
let mut breaks = breaks;
let sorted_breaks = breaks.as_mut_slice();
sorted_breaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
polars_ensure!(sorted_breaks.windows(2).all(|x| x[0] != x[1]), Duplicate: "Breaks are not unique");

polars_ensure!(sorted_breaks[0] > f64::NEG_INFINITY, ComputeError: "Don't include -inf in breaks");
polars_ensure!(sorted_breaks[sorted_breaks.len() - 1] < f64::INFINITY, ComputeError: "Don't include inf in breaks");

let cutlabs = match labels {
Some(ll) => {
polars_ensure!(ll.len() == sorted_breaks.len() + 1, ShapeMismatch: "Provide nbreaks + 1 labels");
ll
}
None => (once(&f64::NEG_INFINITY).chain(sorted_breaks.iter()))
.zip(sorted_breaks.iter().chain(once(&f64::INFINITY)))
.map(|v| {
if left_closed {
format!("[{}, {})", v.0, v.1)
} else {
format!("({}, {}]", v.0, v.1)
}
})
.collect::<Vec<String>>(),
};

let cl: Vec<&str> = cutlabs.iter().map(String::as_str).collect();
let s_flt = s.cast(&DataType::Float64)?;
let bin_iter = s_flt.f64()?.into_iter();

let out_name = format!("{}_bin", s.name());
let mut bld = CategoricalChunkedBuilder::new(&out_name, s.len());
unsafe {
if left_closed {
bld.drain_iter(bin_iter.map(|opt| {
opt.map(|x| *cl.get_unchecked(sorted_breaks.partition_point(|&v| x >= v)))
}));
} else {
bld.drain_iter(bin_iter.map(|opt| {
opt.map(|x| *cl.get_unchecked(sorted_breaks.partition_point(|&v| x > v)))
}));
}
}
Ok(bld.finish().into_series())
}

pub fn qcut(
s: &Series,
probs: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
allow_duplicates: bool,
) -> PolarsResult<Series> {
let s = s.cast(&DataType::Float64)?;
let s2 = s.sort(false);
let ca = s2.f64()?;
let f = |&p| {
ca.quantile(p, QuantileInterpolOptions::Linear)
.unwrap()
.unwrap()
};
let mut qbreaks: Vec<_> = probs.iter().map(f).collect();
qbreaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());

// When probs are spaced too closely for the number of repeated values in the distribution
// some quantiles may be duplicated. The only thing to do if we want to go on, is to drop
// the repeated values and live with some bins being larger than intended.
if allow_duplicates {
let lfilt = match labels {
None => None,
Some(ll) => {
polars_ensure!(ll.len() == qbreaks.len() + 1,
ShapeMismatch: "Wrong number of labels");
let blen = ll.len();
Some(
ll.into_iter()
.enumerate()
.filter(|(i, _)| *i == 0 || *i == blen || qbreaks[*i] != qbreaks[i - 1])
.unzip::<_, _, Vec<_>, Vec<_>>()
.1,
)
}
};
qbreaks.dedup();
return cut(&s, qbreaks, lfilt, left_closed);
}
cut(&s, qbreaks, labels, left_closed)
}
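To make the default labelling and the duplicate handling concrete, a small worked sketch against the series-level functions (assumes `polars-ops` built with the `cutqcut` feature; expected values follow the logic above rather than an executed run):

use polars_core::prelude::*;
use polars_ops::prelude::{cut, qcut};

fn main() -> PolarsResult<()> {
    let s = Series::new("x", &[-1.0f64, 0.5, 1.5, 2.5]);

    // Right-closed bins over breaks [0.0, 2.0]; default labels should be
    // "(-inf, 0]", "(0, 2]", "(2, inf]" and the output is named "x_bin".
    let binned = cut(&s, vec![0.0, 2.0], None, false)?;
    assert_eq!(binned.name(), "x_bin");

    // With heavily repeated values the requested quantiles coincide, so the
    // breaks are deduplicated; this only succeeds because allow_duplicates
    // is true (otherwise cut would reject the non-unique breaks).
    let skewed = Series::new("y", &[1.0f64, 1.0, 1.0, 1.0, 5.0]);
    let q = qcut(&skewed, vec![0.25, 0.5, 0.75], None, false, true)?;
    assert_eq!(q.len(), skewed.len());
    Ok(())
}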
4 changes: 4 additions & 0 deletions polars/polars-ops/src/series/ops/mod.rs
@@ -2,6 +2,8 @@ mod approx_algo;
#[cfg(feature = "approx_unique")]
mod approx_unique;
mod arg_min_max;
#[cfg(feature = "cutqcut")]
mod cut;
#[cfg(feature = "round_series")]
mod floor_divide;
#[cfg(feature = "fused")]
@@ -24,6 +26,8 @@ pub use approx_algo::*;
#[cfg(feature = "approx_unique")]
pub use approx_unique::*;
pub use arg_min_max::ArgAgg;
#[cfg(feature = "cutqcut")]
pub use cut::*;
#[cfg(feature = "round_series")]
pub use floor_divide::*;
#[cfg(feature = "fused")]
2 changes: 2 additions & 0 deletions py-polars/Cargo.toml
@@ -68,6 +68,7 @@ list_count = ["polars/list_count"]
binary_encoding = ["polars/binary_encoding"]
list_sets = ["polars-lazy/list_sets"]
list_any_all = ["polars/list_any_all"]
cutqcut = ["polars/cutqcut"]

all = [
"json",
@@ -105,6 +106,7 @@ all = [
"list_count",
"list_sets",
"list_any_all",
"cutqcut",
]

# we cannot conditionally activate simd