Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust, python): Remove old cut/qcut #9763

Merged
merged 5 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 0 additions & 150 deletions polars/polars-algo/src/algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,153 +99,3 @@ pub fn hist(s: &Series, bins: Option<&Series>, bin_count: Option<usize>) -> Resu
.fill_null(FillNullStrategy::Zero)?
.sort(["category"], false)
}

/// Bin `s` at the break points given by its own quantiles.
///
/// `s` is cast to `Float64` and sorted once; every level in `quantiles` is then
/// evaluated on the sorted data with linear interpolation. Levels for which no
/// quantile is produced (`None`) are skipped. The resulting break points are
/// forwarded to [`cut`] together with the remaining arguments.
///
/// # Errors
/// Propagates errors from the cast, from the quantile computation and from [`cut`].
pub fn qcut(
    s: &Series,
    quantiles: &[f64],
    labels: Option<Vec<&str>>,
    break_point_label: Option<&str>,
    category_label: Option<&str>,
    maintain_order: bool,
) -> PolarsResult<DataFrame> {
    let as_float = s.cast(&DataType::Float64)?;

    // Sort once up front so that every quantile lookup works on sorted data.
    let sorted = as_float.sort(false);
    let sorted_values = sorted.f64().unwrap();

    // `transpose` lets `filter_map` drop the `None` quantiles while the
    // `collect` into a `PolarsResult` still stops at the first error.
    let break_points = quantiles
        .iter()
        .filter_map(|&level| {
            sorted_values
                .quantile(level, QuantileInterpolOptions::Linear)
                .transpose()
        })
        .collect::<PolarsResult<Vec<f64>>>()?;
    let bins = Series::new("", break_points);

    // When the row order is irrelevant we hand over the already sorted data,
    // which saves `cut` an extra sort.
    let input = if maintain_order { &as_float } else { &sorted };
    cut(
        input,
        bins,
        labels,
        break_point_label,
        category_label,
        maintain_order,
    )
}

/// Bin the values of `s` into the intervals delimited by `bins`.
///
/// The result is a `DataFrame` holding the (Float64-cast) input column, the
/// break point each value was matched with via a forward as-of join, and a
/// categorical column with the label of that interval.
///
/// * `bins` - numeric break points; a trailing `+inf` break point is appended.
///   The break points are flagged as sorted ascending without being checked,
///   so the caller is expected to pass them in ascending order.
/// * `labels` - optional labels, one per interval, i.e. `bins.len() + 1` of them.
///   When absent, labels of the form `"(lower, upper]"` are generated.
/// * `break_point_label` - name of the break point column (default `"break_point"`).
/// * `category_label` - name of the category column (default `"category"`).
/// * `maintain_order` - when `true` the rows are returned in input order,
///   otherwise they are left sorted by value.
///
/// # Errors
/// Returns `ComputeError` when `bins` cannot be cast to `Float64`,
/// `ShapeMismatch` when the number of labels is not `bins.len() + 1`, and
/// propagates any error from the underlying frame operations.
pub fn cut(
    s: &Series,
    mut bins: Series,
    labels: Option<Vec<&str>>,
    break_point_label: Option<&str>,
    category_label: Option<&str>,
    maintain_order: bool,
) -> PolarsResult<DataFrame> {
    let var_name = s.name();
    let breakpoint_str = break_point_label.unwrap_or("break_point");
    let category_str = category_label.unwrap_or("category");

    let bins_len = bins.len();

    bins.rename(breakpoint_str);

    // The extra `+inf` break point closes the last interval so that every
    // value finds a match in the forward as-of join below.
    let mut s_bins = bins
        .cast(&DataType::Float64)
        .map_err(|_| PolarsError::ComputeError("expected numeric bins".into()))?
        .extend_constant(AnyValue::Float64(f64::INFINITY), 1)?;
    s_bins.set_sorted_flag(IsSorted::Ascending);
    let cuts_df = df![
        breakpoint_str => s_bins
    ]?;

    let cuts_df = if let Some(labels) = labels {
        // One label per interval: `bins_len` break points delimit `bins_len + 1` intervals.
        polars_ensure!(
            labels.len() == (bins_len + 1),
            ShapeMismatch: "labels count must equal bins count + 1",
        );
        cuts_df
            .lazy()
            .with_column(lit(Series::new(category_str, labels)))
    } else {
        // Default labels: "(previous break point, break point]", with `-inf`
        // as the lower bound of the first interval.
        cuts_df.lazy().with_column(
            format_str(
                "({}, {}]",
                [
                    col(breakpoint_str).shift_and_fill(1, lit(f64::NEG_INFINITY)),
                    col(breakpoint_str),
                ],
            )?
            .alias(category_str),
        )
    }
    .collect()?;

    const ROW_COUNT: &str = "__POLARS_IDX";

    let cuts = cuts_df
        .lazy()
        .with_columns([col(category_str).cast(DataType::Categorical(None))])
        .collect()?;

    let mut s = s.cast(&DataType::Float64)?;
    // Nulls are temporarily replaced by the max bound so they can take part in
    // the sort/join; the mask remembers where to restore them afterwards.
    let valids = if s.null_count() > 0 {
        let valids = Some(s.is_not_null());
        s = s.fill_null(FillNullStrategy::MaxBound).unwrap();
        valids
    } else {
        None
    };
    let mut frame = s.clone().into_frame();

    // The row index allows restoring the input order after the sorted join.
    if maintain_order {
        frame = frame.with_row_count(ROW_COUNT, None)?;
    }

    let mut out = frame.sort(vec![var_name], vec![false])?.join_asof(
        &cuts,
        var_name,
        breakpoint_str,
        AsofStrategy::Forward,
        None,
        None,
    )?;

    if maintain_order {
        out = out.sort([ROW_COUNT], false)?.drop(ROW_COUNT).unwrap()
    };

    if let Some(mut valids) = valids {
        if !maintain_order {
            // The output is sorted by value, so the mask has to be permuted
            // the same way.
            let idx = s.arg_sort(SortOptions {
                nulls_last: true,
                ..Default::default()
            });
            // SAFETY: `idx` is an arg-sort of `s`, and `valids` was derived from `s`,
            // so every index is within the bounds of `valids`.
            valids = unsafe { valids.take_unchecked((&idx).into()) };
        }

        // Only the first chunk is read below, while the output columns are
        // rechunked to a single chunk; rechunk the mask as well so that it
        // spans all rows.
        let valids = valids.rechunk();
        let arr = valids.downcast_iter().next().unwrap();
        let validity = arr.values().clone();

        // SAFETY: we don't change the length/dtype
        unsafe {
            for col in out.get_columns_mut() {
                let mut s = col.rechunk();
                let chunks = s.chunks_mut();
                chunks[0] = chunks[0].with_validity(Some(validity.clone()));
                *col = s;
            }
        }
    }
    Ok(out)
}
2 changes: 1 addition & 1 deletion polars/polars-algo/src/prelude.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pub use crate::{cut, hist};
pub use crate::hist;
15 changes: 13 additions & 2 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,15 @@ pub enum FunctionExpr {
breaks: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
include_breaks: bool,
},
#[cfg(feature = "cutqcut")]
QCut {
probs: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
allow_duplicates: bool,
include_breaks: bool,
},
ToPhysical,
#[cfg(feature = "random")]
Expand Down Expand Up @@ -554,19 +556,28 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
breaks,
labels,
left_closed,
} => map!(cut, breaks.clone(), labels.clone(), left_closed),
include_breaks,
} => map!(
cut,
breaks.clone(),
labels.clone(),
left_closed,
include_breaks
),
#[cfg(feature = "cutqcut")]
QCut {
probs,
labels,
left_closed,
allow_duplicates,
include_breaks,
} => map!(
qcut,
probs.clone(),
labels.clone(),
left_closed,
allow_duplicates
allow_duplicates,
include_breaks
),
ToPhysical => map!(dispatch::to_physical),
#[cfg(feature = "random")]
Expand Down
11 changes: 10 additions & 1 deletion polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1465,11 +1465,18 @@ impl Expr {
}

#[cfg(feature = "cutqcut")]
pub fn cut(self, breaks: Vec<f64>, labels: Option<Vec<String>>, left_closed: bool) -> Expr {
pub fn cut(
self,
breaks: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
include_breaks: bool,
) -> Expr {
self.apply_private(FunctionExpr::Cut {
breaks,
labels,
left_closed,
include_breaks,
})
}

Expand All @@ -1480,12 +1487,14 @@ impl Expr {
labels: Option<Vec<String>>,
left_closed: bool,
allow_duplicates: bool,
include_breaks: bool,
) -> Expr {
self.apply_private(FunctionExpr::QCut {
probs,
labels,
left_closed,
allow_duplicates,
include_breaks,
})
}

Expand Down
80 changes: 60 additions & 20 deletions polars/polars-ops/src/series/ops/cut.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,68 @@
use std::cmp::PartialOrd;
use std::iter::once;

use polars_core::prelude::*;

/// Map every value of `s` to the label of the bin it falls into.
///
/// `s` is cast to `Float64`; nulls and NaNs yield a null. For every other
/// value the bin index is the number of leading breaks in `sorted_breaks` the
/// value is past, where "past" means `>=` when `left_closed` is set and `>`
/// otherwise. The label at that index of `cutlabs` is emitted into a
/// categorical series named `"{name}_bin"`.
///
/// With `include_breaks` the result is a struct series of the same name with
/// two fields: a `Float64` field `"brk"` holding the right end of the bin
/// (`+inf` for the last bin) followed by the categorical field.
///
/// The labels are accessed unchecked: `cutlabs` must contain
/// `sorted_breaks.len() + 1` entries.
fn map_cats(
    s: &Series,
    cutlabs: &[String],
    sorted_breaks: &[f64],
    left_closed: bool,
    include_breaks: bool,
) -> PolarsResult<Series> {
    let out_name = format!("{}_bin", s.name());
    let mut categories = CategoricalChunkedBuilder::new(&out_name, s.len());
    let floats = s.cast(&DataType::Float64)?;
    // It would be nice to parallelize this
    let values = floats.f64()?.into_iter();

    // Comparison deciding whether a value lies past a given break.
    let past_break: fn(&f64, &f64) -> bool = if left_closed {
        <f64 as PartialOrd>::ge
    } else {
        <f64 as PartialOrd>::gt
    };
    // Index of the bin of a value; `None` for nulls and NaNs.
    let bin_index = |opt: Option<f64>| -> Option<usize> {
        opt.filter(|x| !x.is_nan())
            .map(|x| sorted_breaks.partition_point(|brk| past_break(&x, brk)))
    };

    if !include_breaks {
        categories.drain_iter(values.map(|opt| {
            // SAFETY: `partition_point` returns at most `sorted_breaks.len()`,
            // and the caller provides `sorted_breaks.len() + 1` labels.
            bin_index(opt).map(|idx| unsafe { cutlabs.get_unchecked(idx).as_str() })
        }));
        return Ok(categories.finish().into_series());
    }

    // This is to replicate the behavior of the old buggy version that only worked on series and
    // returned a dataframe. That included a column of the right endpoint of the interval. So we
    // return a struct series instead which can be turned into a dataframe later.
    let mut right_ends = Vec::with_capacity(sorted_breaks.len() + 1);
    right_ends.extend_from_slice(sorted_breaks);
    right_ends.push(f64::INFINITY);

    let mut break_values = PrimitiveChunkedBuilder::<Float64Type>::new("brk", s.len());
    for opt in values {
        match bin_index(opt) {
            // SAFETY: `idx <= sorted_breaks.len()`; `right_ends` has
            // `sorted_breaks.len() + 1` entries and so must `cutlabs`.
            Some(idx) => unsafe {
                categories.append_value(cutlabs.get_unchecked(idx).as_str());
                break_values.append_value(*right_ends.get_unchecked(idx));
            },
            None => {
                categories.append_null();
                break_values.append_null();
            }
        }
    }

    let fields = vec![
        break_values.finish().into_series(),
        categories.finish().into_series(),
    ];
    Ok(StructChunked::new(&out_name, &fields)?.into_series())
}

pub fn cut(
s: &Series,
breaks: Vec<f64>,
labels: Option<Vec<String>>,
left_closed: bool,
include_breaks: bool,
) -> PolarsResult<Series> {
polars_ensure!(!breaks.is_empty(), ShapeMismatch: "Breaks are empty");
polars_ensure!(!breaks.iter().any(|x| x.is_nan()), ComputeError: "Breaks cannot be NaN");
Expand Down Expand Up @@ -36,24 +92,7 @@ pub fn cut(
.collect::<Vec<String>>(),
};

let cl: Vec<&str> = cutlabs.iter().map(String::as_str).collect();
let s_flt = s.cast(&DataType::Float64)?;
let bin_iter = s_flt.f64()?.into_iter();

let out_name = format!("{}_bin", s.name());
let mut bld = CategoricalChunkedBuilder::new(&out_name, s.len());
unsafe {
if left_closed {
bld.drain_iter(bin_iter.map(|opt| {
opt.map(|x| *cl.get_unchecked(sorted_breaks.partition_point(|&v| x >= v)))
}));
} else {
bld.drain_iter(bin_iter.map(|opt| {
opt.map(|x| *cl.get_unchecked(sorted_breaks.partition_point(|&v| x > v)))
}));
}
}
Ok(bld.finish().into_series())
map_cats(s, &cutlabs, sorted_breaks, left_closed, include_breaks)
}

pub fn qcut(
Expand All @@ -62,6 +101,7 @@ pub fn qcut(
labels: Option<Vec<String>>,
left_closed: bool,
allow_duplicates: bool,
include_breaks: bool,
) -> PolarsResult<Series> {
let s = s.cast(&DataType::Float64)?;
let s2 = s.sort(false);
Expand Down Expand Up @@ -94,7 +134,7 @@ pub fn qcut(
}
};
qbreaks.dedup();
return cut(&s, qbreaks, lfilt, left_closed);
return cut(&s, qbreaks, lfilt, left_closed, include_breaks);
}
cut(&s, qbreaks, labels, left_closed)
cut(&s, qbreaks, labels, left_closed, include_breaks)
}
2 changes: 1 addition & 1 deletion py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Manipulation/selection
Expr.clip
Expr.clip_max
Expr.clip_min
Expr.cut
Expr.drop_nans
Expr.drop_nulls
Expr.explode
Expand All @@ -33,6 +34,7 @@ Manipulation/selection
Expr.lower_bound
Expr.map_dict
Expr.pipe
Expr.qcut
Expr.rechunk
Expr.reinterpret
Expr.repeat_by
Expand Down
Loading