diff --git a/polars/Cargo.toml b/polars/Cargo.toml index 739c222ff630..0fe672f612c1 100644 --- a/polars/Cargo.toml +++ b/polars/Cargo.toml @@ -160,6 +160,7 @@ streaming = ["polars-lazy/streaming"] fused = ["polars-ops/fused", "polars-lazy/fused"] list_sets = ["polars-lazy/list_sets"] list_any_all = ["polars-lazy/list_any_all"] +cutqcut = ["polars-lazy/cutqcut"] test = [ "lazy", diff --git a/polars/polars-lazy/Cargo.toml b/polars/polars-lazy/Cargo.toml index 4bf30f934b2f..1ef35cf92c97 100644 --- a/polars/polars-lazy/Cargo.toml +++ b/polars/polars-lazy/Cargo.toml @@ -133,6 +133,7 @@ serde = [ fused = ["polars-plan/fused", "polars-ops/fused"] list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"] list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"] +cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"] binary_encoding = ["polars-plan/binary_encoding"] diff --git a/polars/polars-lazy/polars-plan/Cargo.toml b/polars/polars-lazy/polars-plan/Cargo.toml index 34da3b805b8b..3161b185b9dc 100644 --- a/polars/polars-lazy/polars-plan/Cargo.toml +++ b/polars/polars-lazy/polars-plan/Cargo.toml @@ -121,6 +121,7 @@ coalesce = [] fused = [] list_sets = ["polars-ops/list_sets"] list_any_all = ["polars-ops/list_any_all"] +cutqcut = ["polars-ops/cutqcut"] bigidx = ["polars-arrow/bigidx", "polars-core/bigidx", "polars-utils/bigidx"] diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs index fd2df180817d..7668e6760444 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs @@ -63,6 +63,8 @@ pub(crate) use correlation::CorrelationMethod; pub(crate) use fused::FusedOperator; pub(super) use list::ListFunction; use polars_core::prelude::*; +#[cfg(feature = "cutqcut")] +use polars_ops::prelude::{cut, qcut}; #[cfg(feature = "random")] pub(crate) use random::RandomMethod; use schema::FieldsMapper; @@ 
-198,6 +200,19 @@ pub enum FunctionExpr { method: correlation::CorrelationMethod, ddof: u8, }, + #[cfg(feature = "cutqcut")] + Cut { + breaks: Vec<f64>, + labels: Option<Vec<String>>, + left_closed: bool, + }, + #[cfg(feature = "cutqcut")] + QCut { + probs: Vec<f64>, + labels: Option<Vec<String>>, + left_closed: bool, + allow_duplicates: bool, + }, ToPhysical, #[cfg(feature = "random")] Random { @@ -301,6 +316,10 @@ impl Display for FunctionExpr { ArrayExpr(af) => return Display::fmt(af, f), ConcatExpr(_) => "concat_expr", Correlation { method, .. } => return Display::fmt(method, f), + #[cfg(feature = "cutqcut")] + Cut { .. } => "cut", + #[cfg(feature = "cutqcut")] + QCut { .. } => "qcut", ToPhysical => "to_physical", #[cfg(feature = "random")] Random { method, .. } => method.into(), @@ -530,6 +549,25 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> { Fused(op) => map_as_slice!(fused::fused, op), ConcatExpr(rechunk) => map_as_slice!(concat::concat_expr, rechunk), Correlation { method, ddof } => map_as_slice!(correlation::corr, ddof, method), + #[cfg(feature = "cutqcut")] + Cut { + breaks, + labels, + left_closed, + } => map!(cut, breaks.clone(), labels.clone(), left_closed), + #[cfg(feature = "cutqcut")] + QCut { + probs, + labels, + left_closed, + allow_duplicates, + } => map!( + qcut, + probs.clone(), + labels.clone(), + left_closed, + allow_duplicates + ), ToPhysical => map!(dispatch::to_physical), #[cfg(feature = "random")] Random { diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs index f389060404ad..a4ad30a1bd33 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs @@ -234,6 +234,10 @@ impl FunctionExpr { Fused(_) => mapper.map_to_supertype(), ConcatExpr(_) => mapper.map_to_supertype(), Correlation { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "cutqcut")] + Cut { ..
} => mapper.with_dtype(DataType::Categorical(None)), + #[cfg(feature = "cutqcut")] + QCut { .. } => mapper.with_dtype(DataType::Categorical(None)), ToPhysical => mapper.to_physical_type(), #[cfg(feature = "random")] Random { .. } => mapper.with_same_dtype(), diff --git a/polars/polars-lazy/polars-plan/src/dsl/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/mod.rs index 77d5dd20e719..3fd04e99c032 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/mod.rs @@ -1464,6 +1464,39 @@ impl Expr { .with_fmt("rank") } + #[cfg(feature = "cutqcut")] + pub fn cut(self, breaks: Vec<f64>, labels: Option<Vec<String>>, left_closed: bool) -> Expr { + self.apply_private(FunctionExpr::Cut { + breaks, + labels, + left_closed, + }) + .with_function_options(|mut opt| { + opt.allow_group_aware = false; + opt + }) + } + + #[cfg(feature = "cutqcut")] + pub fn qcut( + self, + probs: Vec<f64>, + labels: Option<Vec<String>>, + left_closed: bool, + allow_duplicates: bool, + ) -> Expr { + self.apply_private(FunctionExpr::QCut { + probs, + labels, + left_closed, + allow_duplicates, + }) + .with_function_options(|mut opt| { + opt.allow_group_aware = false; + opt + }) + } + #[cfg(feature = "diff")] pub fn diff(self, n: i64, null_behavior: NullBehavior) -> Expr { self.apply_private(FunctionExpr::Diff(n, null_behavior)) diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/options.rs b/polars/polars-lazy/polars-plan/src/logical_plan/options.rs index e771d4dc24fd..abb5c25bf18f 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/options.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/options.rs @@ -190,6 +190,9 @@ pub struct FunctionOptions { /// Collect groups to a list and apply the function over the groups. /// This can be important in aggregation context.
pub collect_groups: ApplyOptions, + // used for formatting, (only for anonymous functions) + #[cfg_attr(feature = "serde", serde(skip_deserializing))] + pub fmt_str: &'static str, /// There can be two ways of expanding wildcards: /// /// Say the schema is 'a', 'b' and there is a function f @@ -205,7 +208,6 @@ pub struct FunctionOptions { /// /// this also accounts for regex expansion pub input_wildcard_expansion: bool, - /// automatically explode on unit length it ran as final aggregation. /// /// this is the case for aggregations like sum, min, covariance etc. @@ -217,10 +219,6 @@ pub struct FunctionOptions { /// head_1(x) -> {1} /// sum(x) -> {4} pub auto_explode: bool, - // used for formatting, (only for anonymous functions) - #[cfg_attr(feature = "serde", serde(skip_deserializing))] - pub fmt_str: &'static str, - // if the expression and its inputs should be cast to supertypes pub cast_to_supertypes: bool, // apply physical expression may rename the output of this function @@ -233,6 +231,7 @@ pub struct FunctionOptions { // Validate the output of a `map`. 
// this should always be true or we could OOB pub check_lengths: UnsafeBool, + pub allow_group_aware: bool, } impl FunctionOptions { @@ -265,6 +264,7 @@ impl Default for FunctionOptions { pass_name_to_apply: false, changes_length: false, check_lengths: UnsafeBool(true), + allow_group_aware: true, } } } diff --git a/polars/polars-lazy/src/physical_plan/expressions/apply.rs b/polars/polars-lazy/src/physical_plan/expressions/apply.rs index 2a840d71d284..cae2bcbed458 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/apply.rs @@ -26,6 +26,7 @@ pub struct ApplyExpr { pub input_schema: Option<SchemaRef>, pub allow_threading: bool, pub check_lengths: bool, + pub allow_group_aware: bool, } impl ApplyExpr { @@ -46,6 +47,7 @@ impl ApplyExpr { input_schema: None, allow_threading: true, check_lengths: true, + allow_group_aware: true, } } @@ -280,6 +282,11 @@ impl PhysicalExpr for ApplyExpr { groups: &'a GroupsProxy, state: &ExecutionState, ) -> PolarsResult<AggregationContext<'a>> { + polars_ensure!( + self.allow_group_aware, + expr = self.expr, + ComputeError: "this expression cannot run in the groupby context", + ); if self.inputs.len() == 1 { let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?; diff --git a/polars/polars-lazy/src/physical_plan/planner/expr.rs b/polars/polars-lazy/src/physical_plan/planner/expr.rs index dba817d10116..6e6431622ec7 100644 --- a/polars/polars-lazy/src/physical_plan/planner/expr.rs +++ b/polars/polars-lazy/src/physical_plan/planner/expr.rs @@ -464,6 +464,7 @@ pub(crate) fn create_physical_expr( input_schema: schema.cloned(), allow_threading: !state.has_cache, check_lengths: options.check_lengths(), + allow_group_aware: options.allow_group_aware, })) } Function { @@ -499,6 +500,7 @@ pub(crate) fn create_physical_expr( input_schema: schema.cloned(), allow_threading: !state.has_cache, check_lengths: options.check_lengths(), + allow_group_aware: options.allow_group_aware, })) } Slice { diff
--git a/polars/polars-ops/Cargo.toml b/polars/polars-ops/Cargo.toml index f7de9ffd986e..26a5e11a0518 100644 --- a/polars/polars-ops/Cargo.toml +++ b/polars/polars-ops/Cargo.toml @@ -50,6 +50,7 @@ is_first = [] is_unique = [] approx_unique = [] fused = [] +cutqcut = ["dtype-categorical"] # extra utilities for BinaryChunked binary_encoding = ["base64", "hex"] diff --git a/polars/polars-ops/src/series/ops/cut.rs b/polars/polars-ops/src/series/ops/cut.rs new file mode 100644 index 000000000000..82be7be970b0 --- /dev/null +++ b/polars/polars-ops/src/series/ops/cut.rs @@ -0,0 +1,100 @@ +use std::iter::once; + +use polars_core::prelude::*; + +pub fn cut( + s: &Series, + breaks: Vec<f64>, + labels: Option<Vec<String>>, + left_closed: bool, +) -> PolarsResult<Series> { + polars_ensure!(!breaks.is_empty(), ShapeMismatch: "Breaks are empty"); + polars_ensure!(!breaks.iter().any(|x| x.is_nan()), ComputeError: "Breaks cannot be NaN"); + // Breaks must be sorted to cut inputs properly. + let mut breaks = breaks; + let sorted_breaks = breaks.as_mut_slice(); + sorted_breaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + polars_ensure!(sorted_breaks.windows(2).all(|x| x[0] != x[1]), Duplicate: "Breaks are not unique"); + + polars_ensure!(sorted_breaks[0] > f64::NEG_INFINITY, ComputeError: "Don't include -inf in breaks"); + polars_ensure!(sorted_breaks[sorted_breaks.len() - 1] < f64::INFINITY, ComputeError: "Don't include inf in breaks"); + + let cutlabs = match labels { + Some(ll) => { + polars_ensure!(ll.len() == sorted_breaks.len() + 1, ShapeMismatch: "Provide nbreaks + 1 labels"); + ll + } + None => (once(&f64::NEG_INFINITY).chain(sorted_breaks.iter())) + .zip(sorted_breaks.iter().chain(once(&f64::INFINITY))) + .map(|v| { + if left_closed { + format!("[{}, {})", v.0, v.1) + } else { + format!("({}, {}]", v.0, v.1) + } + }) + .collect::<Vec<String>>(), + }; + + let cl: Vec<&str> = cutlabs.iter().map(String::as_str).collect(); + let s_flt = s.cast(&DataType::Float64)?; + let bin_iter =
s_flt.f64()?.into_iter(); + + let out_name = format!("{}_bin", s.name()); + let mut bld = CategoricalChunkedBuilder::new(&out_name, s.len()); + unsafe { + if left_closed { + bld.drain_iter(bin_iter.map(|opt| { + opt.map(|x| *cl.get_unchecked(sorted_breaks.partition_point(|&v| x >= v))) + })); + } else { + bld.drain_iter(bin_iter.map(|opt| { + opt.map(|x| *cl.get_unchecked(sorted_breaks.partition_point(|&v| x > v))) + })); + } + } + Ok(bld.finish().into_series()) +} + +pub fn qcut( + s: &Series, + probs: Vec<f64>, + labels: Option<Vec<String>>, + left_closed: bool, + allow_duplicates: bool, +) -> PolarsResult<Series> { + let s = s.cast(&DataType::Float64)?; + let s2 = s.sort(false); + let ca = s2.f64()?; + let f = |&p| { + ca.quantile(p, QuantileInterpolOptions::Linear) + .unwrap() + .unwrap() + }; + let mut qbreaks: Vec<_> = probs.iter().map(f).collect(); + qbreaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + // When probs are spaced too closely for the number of repeated values in the distribution + // some quantiles may be duplicated. The only thing to do if we want to go on, is to drop + // the repeated values and live with some bins being larger than intended.
+ if allow_duplicates { + let lfilt = match labels { + None => None, + Some(ll) => { + polars_ensure!(ll.len() == qbreaks.len() + 1, + ShapeMismatch: "Wrong number of labels"); + let blen = ll.len(); + Some( + ll.into_iter() + .enumerate() + .filter(|(i, _)| *i == 0 || *i == blen || qbreaks[*i] != qbreaks[i - 1]) + .unzip::<_, _, Vec<_>, Vec<_>>() + .1, + ) + } + }; + qbreaks.dedup(); + return cut(&s, qbreaks, lfilt, left_closed); + } + cut(&s, qbreaks, labels, left_closed) +} diff --git a/polars/polars-ops/src/series/ops/mod.rs b/polars/polars-ops/src/series/ops/mod.rs index 04ce2a59487d..0b06f21cc20f 100644 --- a/polars/polars-ops/src/series/ops/mod.rs +++ b/polars/polars-ops/src/series/ops/mod.rs @@ -2,6 +2,8 @@ mod approx_algo; #[cfg(feature = "approx_unique")] mod approx_unique; mod arg_min_max; +#[cfg(feature = "cutqcut")] +mod cut; #[cfg(feature = "round_series")] mod floor_divide; #[cfg(feature = "fused")] @@ -24,6 +26,8 @@ pub use approx_algo::*; #[cfg(feature = "approx_unique")] pub use approx_unique::*; pub use arg_min_max::ArgAgg; +#[cfg(feature = "cutqcut")] +pub use cut::*; #[cfg(feature = "round_series")] pub use floor_divide::*; #[cfg(feature = "fused")] diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 56a8691b092c..4bfd2818c5f7 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -68,6 +68,7 @@ list_count = ["polars/list_count"] binary_encoding = ["polars/binary_encoding"] list_sets = ["polars-lazy/list_sets"] list_any_all = ["polars/list_any_all"] +cutqcut = ["polars/cutqcut"] all = [ "json", @@ -105,6 +106,7 @@ all = [ "list_count", "list_sets", "list_any_all", + "cutqcut", ] # we cannot conditionally activate simd diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index d2db83d3bafc..5a7a22e4c378 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -3234,6 +3234,57 @@ def quantile( quantile = parse_as_expression(quantile) return 
self._from_pyexpr(self._pyexpr.quantile(quantile, interpolation)) + def cut( + self, + breaks: list[float], + labels: list[str] | None = None, + left_closed: bool = False, + ) -> Self: + """ + Bin continuous values into discrete categories. + + Parameters + ---------- + breaks + A list of unique cut points. + labels + Labels to assign to bins. If given, the length must be len(probs) + 1. + left_closed + Whether intervals should be [) instead of the default of (] + + """ + return self._from_pyexpr(self._pyexpr.cut(breaks, labels, left_closed)) + + def qcut( + self, + probs: list[float], + labels: list[str] | None = None, + left_closed: bool = False, + allow_duplicates: bool = False, + ) -> Self: + """ + Bin continuous values into discrete categories based on their quantiles. + + Parameters + ---------- + probs + Probabilities for which to find the corresponding quantiles + For p in probs, we assume 0 <= p <= 1 + labels + Labels to assign to bins. If given, the length must be len(probs) + 1. + If computing over a grouping variable we recommend this be set. + left_closed + Whether intervals should be [) instead of the default of (] + allow_duplicates + If True, the resulting quantile breaks don't have to be unique. This can + happen even with unique probs depending on the data. Duplicates will be + dropped, resulting in fewer bins. + + """ + return self._from_pyexpr( + self._pyexpr.qcut(probs, labels, left_closed, allow_duplicates) + ) + def filter(self, predicate: Expr) -> Self: """ Filter a single column. diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 5a13bafcc19d..a6841f51ad9b 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1605,9 +1605,11 @@ def cut( category_label: str = "category", *, maintain_order: bool = False, - ) -> DataFrame: + series: bool = False, + left_closed: bool = False, + ) -> DataFrame | Series: """ - Bin values into discrete values. 
+ Bin continuous values into discrete categories. Parameters ---------- @@ -1617,15 +1619,19 @@ def cut( Labels to assign to the bins. If given the length of labels must be len(bins) + 1. break_point_label - Name given to the breakpoint column. + Name given to the breakpoint column. Only used if series == False category_label - Name given to the category column. + Name given to the category column. Only used if series == False maintain_order - Keep the order of the original `Series`. + Keep the order of the original `Series`. Only used if series == False + series + If True, return the a categorical series in the data's original order + left_closed + Whether intervals should be [) instead of (] Returns ------- - DataFrame + DataFrame or Series Examples -------- @@ -1649,6 +1655,12 @@ def cut( └──────┴─────────────┴──────────────┘ """ + if series: + return ( + self.to_frame() + .select(F.col(self._s.name()).cut(bins, labels, left_closed)) + .to_series() + ) return wrap_df( self._s.cut( Series(break_point_label, bins, dtype=Float64)._s, @@ -1667,9 +1679,12 @@ def qcut( break_point_label: str = "break_point", category_label: str = "category", maintain_order: bool = False, - ) -> DataFrame: + series: bool = False, + left_closed: bool = False, + allow_duplicates: bool = False, + ) -> DataFrame | Series: """ - Bin values into discrete values based on their quantiles. + Bin continuous values into discrete categories based on their quantiles. Parameters ---------- @@ -1685,10 +1700,18 @@ def qcut( Name given to the category column. maintain_order Keep the order of the original `Series`. + series + If True, return the a categorical series in the data's original order + left_closed + Whether intervals should be [) instead of (] + allow_duplicates + If True, the resulting quantile breaks don't have to be unique. This can + happen even with unique probs depending on the data. Duplicates will be + dropped, resulting in fewer bins. 
Returns ------- - DataFrame + DataFrame or Series Warnings -------- @@ -1716,6 +1739,16 @@ └──────┴─────────────┴───────────────┘ """ + if series: + return ( + self.to_frame() + .select( + F.col(self._s.name()).qcut( + quantiles, labels, left_closed, allow_duplicates + ) + ) + .to_series() + ) return wrap_df( self._s.qcut( Series(quantiles, dtype=Float64)._s, diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 47e498c155e1..c69505b2f785 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -177,6 +177,27 @@ impl PyExpr { .quantile(quantile.inner, interpolation.0) .into() } + + #[pyo3(signature = (breaks, labels, left_closed))] + #[cfg(feature = "cutqcut")] + fn cut(&self, breaks: Vec<f64>, labels: Option<Vec<String>>, left_closed: bool) -> Self { + self.inner.clone().cut(breaks, labels, left_closed).into() + } + #[pyo3(signature = (probs, labels, left_closed, allow_duplicates))] + #[cfg(feature = "cutqcut")] + fn qcut( + &self, + probs: Vec<f64>, + labels: Option<Vec<String>>, + left_closed: bool, + allow_duplicates: bool, + ) -> Self { + self.inner + .clone() + .qcut(probs, labels, left_closed, allow_duplicates) + .into() + } + fn agg_groups(&self) -> Self { self.clone().inner.agg_groups().into() } diff --git a/py-polars/tests/unit/operations/test_statistics.py b/py-polars/tests/unit/operations/test_statistics.py index 426b069c01ed..6f9e5e81e672 100644 --- a/py-polars/tests/unit/operations/test_statistics.py +++ b/py-polars/tests/unit/operations/test_statistics.py @@ -1,9 +1,11 @@ from datetime import timedelta +from typing import cast import numpy as np +from numpy import inf import polars as pl -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal def test_corr() -> None: @@ -25,7 +27,7 @@ def test_corr() -> None: def test_cut() -> None: a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)]) - out = a.cut(bins=[-1, 1]) + out = cast(pl.DataFrame, a.cut(bins=[-1, 1]))
assert out.shape == (12, 3) assert out.filter(pl.col("break_point") < 1e9).to_dict(False) == { @@ -48,7 +50,7 @@ def test_cut() -> None: inf = float("inf") df = pl.DataFrame({"a": list(range(5))}) ser = df.select("a").to_series() - assert ser.cut(bins=[-1, 1]).rows() == [ + assert cast(pl.DataFrame, ser.cut(bins=[-1, 1])).rows() == [ (0.0, 1.0, "(-1.0, 1.0]"), (1.0, 1.0, "(-1.0, 1.0]"), (2.0, inf, "(1.0, inf]"), @@ -58,20 +60,68 @@ def test_cut() -> None: def test_cut_maintain_order() -> None: + expected_df = pl.DataFrame( + { + "a": [5.0, 8.0, 9.0, 5.0, 0.0, 0.0, 1.0, 7.0, 6.0, 9.0], + "break_point": [inf, inf, inf, inf, 1.0, 1.0, 1.0, inf, inf, inf], + "category": [ + "(1.0, inf]", + "(1.0, inf]", + "(1.0, inf]", + "(1.0, inf]", + "(-1.0, 1.0]", + "(-1.0, 1.0]", + "(-1.0, 1.0]", + "(1.0, inf]", + "(1.0, inf]", + "(1.0, inf]", + ], + } + ) np.random.seed(1) a = pl.Series("a", np.random.randint(0, 10, 10)) - out = a.cut(bins=[-1, 1], maintain_order=True) + out = cast(pl.DataFrame, a.cut(bins=[-1, 1], maintain_order=True)) + out_s = cast(pl.Series, a.cut(bins=[-1, 1], series=True)) assert out["a"].cast(int).series_equal(a) - assert ( - str(out.to_dict(False)) - == "{'a': [5.0, 8.0, 9.0, 5.0, 0.0, 0.0, 1.0, 7.0, 6.0, 9.0], 'break_point': [inf, inf, inf, inf, 1.0, 1.0, 1.0, inf, inf, inf], 'category': ['(1.0, inf]', '(1.0, inf]', '(1.0, inf]', '(1.0, inf]', '(-1.0, 1.0]', '(-1.0, 1.0]', '(-1.0, 1.0]', '(1.0, inf]', '(1.0, inf]', '(1.0, inf]']}" + # Compare strings and categoricals without a hassle + assert_frame_equal(expected_df, out, check_dtype=False) + # It formats differently + assert_series_equal( + pl.Series(["(1, inf]"] * 4 + ["(-1, 1]"] * 3 + ["(1, inf]"] * 3), + out_s, + check_dtype=False, + check_names=False, ) def test_qcut() -> None: - assert ( - str(pl.Series("a", range(-5, 3)).qcut([0.0, 0.25, 0.75]).to_dict(False)) - == "{'a': [-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0], 'break_point': [-5.0, -3.25, 0.25, 0.25, 0.25, 0.25, inf, inf], 'category': 
['(-inf, -5.0]', '(-5.0, -3.25]', '(-3.25, 0.25]', '(-3.25, 0.25]', '(-3.25, 0.25]', '(-3.25, 0.25]', '(0.25, inf]', '(0.25, inf]']}" + input = pl.Series("a", range(-5, 3)) + exp = pl.DataFrame( + { + "a": [-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0], + "break_point": [-5.0, -3.25, 0.25, 0.25, 0.25, 0.25, inf, inf], + "category": [ + "(-inf, -5.0]", + "(-5.0, -3.25]", + "(-3.25, 0.25]", + "(-3.25, 0.25]", + "(-3.25, 0.25]", + "(-3.25, 0.25]", + "(0.25, inf]", + "(0.25, inf]", + ], + } + ) + out = cast(pl.DataFrame, input.qcut([0.0, 0.25, 0.75])) + out_s = cast(pl.Series, input.qcut([0.0, 0.25, 0.75], series=True)) + assert_frame_equal(out, exp, check_dtype=False) + assert_series_equal( + pl.Series( + ["(-inf, -5]", "(-5, -3.25]"] + ["(-3.25, 0.25]"] * 4 + ["(0.25, inf]"] * 2 + ), + out_s, + check_dtype=False, + check_names=False, ) @@ -85,12 +135,43 @@ def test_hist() -> None: def test_cut_null_values() -> None: s = pl.Series([-1.0, None, 1.0, 2.0, None, 8.0, 4.0]) - assert ( - str(s.qcut([0.2, 0.3], maintain_order=True).to_dict(False)) - == "{'': [-1.0, None, 1.0, 2.0, None, 8.0, 4.0], 'break_point': [0.5999999999999996, None, 1.2000000000000002, inf, None, inf, inf], 'category': ['(-inf, 0.5999999999999996]', None, '(0.5999999999999996, 1.2000000000000002]', '(1.2000000000000002, inf]', None, '(1.2000000000000002, inf]', '(1.2000000000000002, inf]']}" + exp = pl.DataFrame( + { + "": [-1.0, None, 1.0, 2.0, None, 8.0, 4.0], + "break_point": [ + 0.5999999999999996, + None, + 1.2000000000000002, + inf, + None, + inf, + inf, + ], + "category": [ + "(-inf, 0.5999999999999996]", + None, + "(0.5999999999999996, 1.2000000000000002]", + "(1.2000000000000002, inf]", + None, + "(1.2000000000000002, inf]", + "(1.2000000000000002, inf]", + ], + } ) + assert_frame_equal( + cast(pl.DataFrame, s.qcut([0.2, 0.3], maintain_order=True)), + exp, + check_dtype=False, + ) + assert_series_equal( + cast(pl.Series, s.qcut([0.2, 0.3], series=True)), + exp.get_column("category"), + 
check_dtype=False, + check_names=False, + ) + assert ( - str(s.qcut([0.2, 0.3], maintain_order=False).to_dict(False)) + str(cast(pl.DataFrame, s.qcut([0.2, 0.3], maintain_order=False)).to_dict(False)) == "{'': [-1.0, 1.0, 2.0, 4.0, 8.0, None, None], 'break_point': [0.5999999999999996, 1.2000000000000002, inf, inf, inf, None, None], 'category': ['(-inf, 0.5999999999999996]', '(0.5999999999999996, 1.2000000000000002]', '(1.2000000000000002, inf]', '(1.2000000000000002, inf]', '(1.2000000000000002, inf]', None, None]}" ) diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index 8a37394c822e..c8b8b43f842e 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -669,3 +669,15 @@ def test_sort_by_err_9259() -> None: df.lazy().groupby("c").agg( [pl.col("a").sort_by(pl.col("b").filter(pl.col("b") > 100)).sum()] ).collect() + + +def test_raise_cut_in_over() -> None: + with pl.StringCache(): + x = pl.Series(range(20)) + r = pl.Series( + [pl.repeat("a", 10, eager=True), pl.repeat("b", 10, eager=True)] + ).explode() + df = pl.DataFrame({"x": x, "g": r}) + + with pytest.raises(pl.ComputeError): + df.with_columns(pl.col("x").qcut([0.5]).over("g").to_physical())