Skip to content

Commit

Permalink
refactor(rust, python): Remove old cut/qcut (#9763)
Browse files Browse the repository at this point in the history
  • Loading branch information
magarick authored Jul 11, 2023
1 parent 5b92319 commit f5f0630
Show file tree
Hide file tree
Showing 12 changed files with 249 additions and 345 deletions.
150 changes: 0 additions & 150 deletions polars/polars-algo/src/algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,153 +99,3 @@ pub fn hist(s: &Series, bins: Option<&Series>, bin_count: Option<usize>) -> Resu
.fill_null(FillNullStrategy::Zero)?
.sort(["category"], false)
}

/// Bins `s` at break points taken from its own quantiles and delegates to [`cut`].
///
/// `s` is cast to `Float64` and a sorted copy is used to evaluate every level in
/// `quantiles` with linear interpolation. Levels for which no quantile value is
/// available are skipped. The resulting break points are passed to [`cut`] together
/// with `labels`, `break_point_label`, `category_label` and `maintain_order`.
///
/// # Errors
///
/// Returns an error if the cast, a quantile computation, or [`cut`] fails.
pub fn qcut(
    s: &Series,
    quantiles: &[f64],
    labels: Option<Vec<&str>>,
    break_point_label: Option<&str>,
    category_label: Option<&str>,
    maintain_order: bool,
) -> PolarsResult<DataFrame> {
    let values = s.cast(&DataType::Float64)?;

    // Sort once up front so all quantile lookups share the same ordered data.
    let sorted = values.sort(false);
    let sorted_ca = sorted.f64().unwrap();

    // `Ok(None)` results are dropped; the first `Err` aborts the whole collection.
    let break_points = quantiles
        .iter()
        .filter_map(|&level| {
            sorted_ca
                .quantile(level, QuantileInterpolOptions::Linear)
                .transpose()
        })
        .collect::<PolarsResult<Vec<f64>>>()?;
    let bins = Series::new("", break_points);

    // When the original order is irrelevant, hand over the sorted copy:
    // it is already sorted, which saves an extra sort.
    let input = if maintain_order { &values } else { &sorted };
    cut(
        input,
        bins,
        labels,
        break_point_label,
        category_label,
        maintain_order,
    )
}

/// Bins the values of `s` against the break points in `bins` and returns a
/// `DataFrame` holding the values next to their break point and category.
///
/// * `s` - values to bin; cast to `Float64` before binning.
/// * `bins` - break points; cast to `Float64`, with `f64::INFINITY` appended as a
///   final break point.
/// * `labels` - optional category labels; must contain `bins.len() + 1` entries.
///   When `None`, labels of the form `"(lower, upper]"` are generated.
/// * `break_point_label` - name of the break point column (default `"break_point"`).
/// * `category_label` - name of the category column (default `"category"`).
/// * `maintain_order` - when `true`, rows are returned in the order of `s`;
///   otherwise they are returned sorted by value.
///
/// # Errors
///
/// Returns a `ComputeError` if `bins` cannot be cast to `Float64`, a shape mismatch
/// error if the number of labels is not `bins.len() + 1`, and propagates any error
/// from the underlying frame operations.
pub fn cut(
s: &Series,
mut bins: Series,
labels: Option<Vec<&str>>,
break_point_label: Option<&str>,
category_label: Option<&str>,
maintain_order: bool,
) -> PolarsResult<DataFrame> {
let var_name = s.name();
let breakpoint_str = break_point_label.unwrap_or("break_point");
let category_str = category_label.unwrap_or("category");

// Number of user-supplied break points, recorded before the infinity sentinel is appended.
let bins_len = bins.len();

bins.rename(breakpoint_str);

// Append +inf as a last break point, giving `bins_len + 1` break points in total.
let mut s_bins = bins
.cast(&DataType::Float64)
.map_err(|_| PolarsError::ComputeError("expected numeric bins".into()))?
.extend_constant(AnyValue::Float64(f64::INFINITY), 1)?;
// NOTE(review): the ascending flag is set without sorting or checking `bins`;
// this assumes the caller passed ascending break points - confirm.
s_bins.set_sorted_flag(IsSorted::Ascending);
let cuts_df = df![
breakpoint_str => s_bins
]?;

// Attach the category column: either the caller's labels or generated interval strings.
let cuts_df = if let Some(labels) = labels {
// NOTE(review): the check requires one label more than there are bins (one per
// break point including the +inf sentinel), while the message says "equal".
polars_ensure!(
labels.len() == (bins_len + 1),
ShapeMismatch: "labels count must equal bins count",
);
cuts_df
.lazy()
.with_column(lit(Series::new(category_str, labels)))
} else {
// Each label is built from the previous break point (shifted by one, with -inf
// filled in for the first row) and the current break point.
cuts_df.lazy().with_column(
format_str(
"({}, {}]",
[
col(breakpoint_str).shift_and_fill(1, lit(f64::NEG_INFINITY)),
col(breakpoint_str),
],
)?
.alias(category_str),
)
}
.collect()?;

// Name of the temporary row-index column used to restore the input order.
const ROW_COUNT: &str = "__POLARS_IDX";

let cuts = cuts_df
.lazy()
.with_columns([col(category_str).cast(DataType::Categorical(None))])
.collect()?;

let mut s = s.cast(&DataType::Float64)?;
// If there are nulls, remember which rows were valid and replace the nulls using the
// `MaxBound` fill strategy; the null mask is re-applied to the output further down.
let valids = if s.null_count() > 0 {
let valids = Some(s.is_not_null());
s = s.fill_null(FillNullStrategy::MaxBound).unwrap();
valids
} else {
None
};
let mut frame = s.clone().into_frame();

if maintain_order {
frame = frame.with_row_count(ROW_COUNT, None)?;
}

// Sort by value, then asof-join (forward strategy) the values against the break points.
let mut out = frame.sort(vec![var_name], vec![false])?.join_asof(
&cuts,
var_name,
breakpoint_str,
AsofStrategy::Forward,
None,
None,
)?;

// Restore the input order via the temporary row index, then drop that column.
if maintain_order {
out = out.sort([ROW_COUNT], false)?.drop(ROW_COUNT).unwrap()
};

if let Some(mut valids) = valids {
if !maintain_order {
// The output is in sorted order here, so the mask is permuted with the sort
// indices of the null-filled values - presumably to line it up with the output rows.
let idx = s.arg_sort(SortOptions {
nulls_last: true,
..Default::default()
});
valids = unsafe { valids.take_unchecked((&idx).into()) };
}

// NOTE(review): only the first chunk of the mask is read; this assumes `valids`
// consists of a single chunk - confirm.
let arr = valids.downcast_iter().next().unwrap();
let validity = arr.values().clone();

// The mask is applied to every output column, so rows that were null in the input
// become null in the value, break point and category columns alike.
// Safety: we don't change the length/dtype
unsafe {
for col in out.get_columns_mut() {
let mut s = col.rechunk();
let chunks = s.chunks_mut();
chunks[0] = chunks[0].with_validity(Some(validity.clone()));
*col = s;
}
}
}
Ok(out)
}
2 changes: 1 addition & 1 deletion polars/polars-algo/src/prelude.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pub use crate::{cut, hist};
pub use crate::hist;
15 changes: 13 additions & 2 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,15 @@ pub enum FunctionExpr {
breaks: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
include_breaks: bool,
},
#[cfg(feature = "cutqcut")]
QCut {
probs: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
allow_duplicates: bool,
include_breaks: bool,
},
ToPhysical,
#[cfg(feature = "random")]
Expand Down Expand Up @@ -554,19 +556,28 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
breaks,
labels,
left_closed,
} => map!(cut, breaks.clone(), labels.clone(), left_closed),
include_breaks,
} => map!(
cut,
breaks.clone(),
labels.clone(),
left_closed,
include_breaks
),
#[cfg(feature = "cutqcut")]
QCut {
probs,
labels,
left_closed,
allow_duplicates,
include_breaks,
} => map!(
qcut,
probs.clone(),
labels.clone(),
left_closed,
allow_duplicates
allow_duplicates,
include_breaks
),
ToPhysical => map!(dispatch::to_physical),
#[cfg(feature = "random")]
Expand Down
11 changes: 10 additions & 1 deletion polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1465,11 +1465,18 @@ impl Expr {
}

#[cfg(feature = "cutqcut")]
pub fn cut(self, breaks: Vec<f64>, labels: Option<Vec<String>>, left_closed: bool) -> Expr {
pub fn cut(
self,
breaks: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
include_breaks: bool,
) -> Expr {
self.apply_private(FunctionExpr::Cut {
breaks,
labels,
left_closed,
include_breaks,
})
}

Expand All @@ -1480,12 +1487,14 @@ impl Expr {
labels: Option<Vec<String>>,
left_closed: bool,
allow_duplicates: bool,
include_breaks: bool,
) -> Expr {
self.apply_private(FunctionExpr::QCut {
probs,
labels,
left_closed,
allow_duplicates,
include_breaks,
})
}

Expand Down
80 changes: 60 additions & 20 deletions polars/polars-ops/src/series/ops/cut.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,68 @@
use std::cmp::PartialOrd;
use std::iter::once;

use polars_core::prelude::*;

/// Assigns every value of `s` to a labelled bin and returns a series named `"{name}_bin"`.
///
/// Values are cast to `Float64`; nulls and NaNs yield a null output row. The bin index of a
/// value is the number of leading entries of `sorted_breaks` that compare below it (`<=` the
/// value when `left_closed`, `<` the value otherwise), and that index selects the entry of
/// `cutlabs` used as the category.
///
/// When `include_breaks` is set, the result is a struct series with two fields: `brk`, the
/// right end point of the bin (`f64::INFINITY` for the last bin), and the categorical label.
/// Otherwise only the categorical series is returned.
///
/// Callers are expected to pass `cutlabs` with exactly one more entry than `sorted_breaks`;
/// this is not checked here and the label lookup is unchecked.
fn map_cats(
    s: &Series,
    cutlabs: &[String],
    sorted_breaks: &[f64],
    left_closed: bool,
    include_breaks: bool,
) -> PolarsResult<Series> {
    let labels: Vec<&str> = cutlabs.iter().map(|lab| lab.as_str()).collect();

    let name = format!("{}_bin", s.name());
    let mut cats = CategoricalChunkedBuilder::new(&name, s.len());
    let as_f64 = s.cast(&DataType::Float64)?;
    // It would be nice to parallelize this
    let values = as_f64.f64()?.into_iter();

    // Index of the bin that `x` falls into; never larger than `sorted_breaks.len()`.
    let bin_of = |x: f64| -> usize {
        if left_closed {
            sorted_breaks.partition_point(|brk| x >= *brk)
        } else {
            sorted_breaks.partition_point(|brk| x > *brk)
        }
    };

    if !include_breaks {
        cats.drain_iter(values.map(|opt| {
            opt.filter(|x| !x.is_nan())
                // SAFETY: `bin_of` returns at most `sorted_breaks.len()`, which is in bounds
                // as long as the caller supplied one label per bin (see function docs).
                .map(|x| unsafe { *labels.get_unchecked(bin_of(x)) })
        }));
        return Ok(cats.finish().into_series());
    }

    // This is to replicate the behavior of the old buggy version that only worked on series
    // and returned a dataframe. That included a column of the right endpoint of the interval.
    // So we return a struct series instead which can be turned into a dataframe later.
    let mut right_ends = Vec::with_capacity(sorted_breaks.len() + 1);
    right_ends.extend_from_slice(sorted_breaks);
    right_ends.push(f64::INFINITY);

    let mut brks = PrimitiveChunkedBuilder::<Float64Type>::new("brk", s.len());
    for opt in values {
        match opt.filter(|x| !x.is_nan()).map(&bin_of) {
            // SAFETY: `idx <= sorted_breaks.len()`, which is the last valid index of
            // `right_ends`, and of `labels` under the caller contract described above.
            Some(idx) => unsafe {
                cats.append_value(labels.get_unchecked(idx));
                brks.append_value(*right_ends.get_unchecked(idx));
            },
            None => {
                cats.append_null();
                brks.append_null();
            }
        }
    }

    let fields = vec![brks.finish().into_series(), cats.finish().into_series()];
    Ok(StructChunked::new(&name, &fields)?.into_series())
}

pub fn cut(
s: &Series,
breaks: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
include_breaks: bool,
) -> PolarsResult<Series> {
polars_ensure!(!breaks.is_empty(), ShapeMismatch: "Breaks are empty");
polars_ensure!(!breaks.iter().any(|x| x.is_nan()), ComputeError: "Breaks cannot be NaN");
Expand Down Expand Up @@ -36,24 +92,7 @@ pub fn cut(
.collect::<Vec<String>>(),
};

let cl: Vec<&str> = cutlabs.iter().map(String::as_str).collect();
let s_flt = s.cast(&DataType::Float64)?;
let bin_iter = s_flt.f64()?.into_iter();

let out_name = format!("{}_bin", s.name());
let mut bld = CategoricalChunkedBuilder::new(&out_name, s.len());
unsafe {
if left_closed {
bld.drain_iter(bin_iter.map(|opt| {
opt.map(|x| *cl.get_unchecked(sorted_breaks.partition_point(|&v| x >= v)))
}));
} else {
bld.drain_iter(bin_iter.map(|opt| {
opt.map(|x| *cl.get_unchecked(sorted_breaks.partition_point(|&v| x > v)))
}));
}
}
Ok(bld.finish().into_series())
map_cats(s, &cutlabs, sorted_breaks, left_closed, include_breaks)
}

pub fn qcut(
Expand All @@ -62,6 +101,7 @@ pub fn qcut(
labels: Option<Vec<String>>,
left_closed: bool,
allow_duplicates: bool,
include_breaks: bool,
) -> PolarsResult<Series> {
let s = s.cast(&DataType::Float64)?;
let s2 = s.sort(false);
Expand Down Expand Up @@ -94,7 +134,7 @@ pub fn qcut(
}
};
qbreaks.dedup();
return cut(&s, qbreaks, lfilt, left_closed);
return cut(&s, qbreaks, lfilt, left_closed, include_breaks);
}
cut(&s, qbreaks, labels, left_closed)
cut(&s, qbreaks, labels, left_closed, include_breaks)
}
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/expressions/modify_select.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Manipulation/selection
Expr.clip
Expr.clip_max
Expr.clip_min
Expr.cut
Expr.drop_nans
Expr.drop_nulls
Expr.explode
Expand All @@ -33,6 +34,7 @@ Manipulation/selection
Expr.lower_bound
Expr.map_dict
Expr.pipe
Expr.qcut
Expr.rechunk
Expr.reinterpret
Expr.repeat_by
Expand Down
Loading

0 comments on commit f5f0630

Please sign in to comment.