From 0c0e086d61f304acbf4d29fb4f0df1b34465a10a Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 17 Apr 2024 16:55:31 +0800 Subject: [PATCH] fix: pass series name to apply for cut/qcut --- crates/polars-plan/src/dsl/mod.rs | 12 ++++++++++++ py-polars/tests/unit/operations/test_cut.py | 10 ++++++++++ 2 files changed, 22 insertions(+) diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 671bdec3117a..cc7a9e484501 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1414,6 +1414,10 @@ impl Expr { left_closed, include_breaks, }) + .with_function_options(|mut opt| { + opt.pass_name_to_apply = true; + opt + }) } #[cfg(feature = "cutqcut")] @@ -1433,6 +1437,10 @@ impl Expr { allow_duplicates, include_breaks, }) + .with_function_options(|mut opt| { + opt.pass_name_to_apply = true; + opt + }) } #[cfg(feature = "cutqcut")] @@ -1453,6 +1461,10 @@ impl Expr { allow_duplicates, include_breaks, }) + .with_function_options(|mut opt| { + opt.pass_name_to_apply = true; + opt + }) } #[cfg(feature = "rle")] diff --git a/py-polars/tests/unit/operations/test_cut.py b/py-polars/tests/unit/operations/test_cut.py index 29c371adeb8d..b87381e94eff 100644 --- a/py-polars/tests/unit/operations/test_cut.py +++ b/py-polars/tests/unit/operations/test_cut.py @@ -110,3 +110,13 @@ def test_cut_deprecated_label_name() -> None: s.cut([0.1], category_label="x") with pytest.deprecated_call(): s.cut([0.1], break_point_label="x") + + +def test_cut_bin_name_in_agg_context() -> None: + df = pl.DataFrame({"a": [1]}).select( + cut=pl.col("a").cut([1, 2], include_breaks=True).over(1), + qcut=pl.col("a").qcut([1], include_breaks=True).over(1), + qcut_uniform=pl.col("a").qcut(1, include_breaks=True).over(1), + ) + schema = pl.Struct({"brk": pl.Float64, "a_bin": pl.Categorical("physical")}) + assert df.schema == {"cut": schema, "qcut": schema, "qcut_uniform": schema}