From 10aee3589e5c6bb74e027e82c6bc4ffb0c08d046 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 28 Aug 2024 15:50:17 +0200 Subject: [PATCH] fix: Enable CSE in eager if struct are expanded (#18426) --- crates/polars-plan/src/frame/opt_state.rs | 2 - .../src/plans/conversion/dsl_to_ir.rs | 70 +++++++++++++------ .../src/plans/conversion/expr_expansion.rs | 45 ++++++++++-- crates/polars-plan/src/plans/optimizer/mod.rs | 11 ++- py-polars/tests/unit/test_cse.py | 12 ++++ 5 files changed, 108 insertions(+), 32 deletions(-) diff --git a/crates/polars-plan/src/frame/opt_state.rs b/crates/polars-plan/src/frame/opt_state.rs index d6ed31a12882..934f42e6109f 100644 --- a/crates/polars-plan/src/frame/opt_state.rs +++ b/crates/polars-plan/src/frame/opt_state.rs @@ -18,11 +18,9 @@ bitflags! { const FILE_CACHING = 1 << 6; /// Pushdown slices/limits. const SLICE_PUSHDOWN = 1 << 7; - #[cfg(feature = "cse")] /// Run common-subplan-elimination. This elides duplicate plans and caches their /// outputs. const COMM_SUBPLAN_ELIM = 1 << 8; - #[cfg(feature = "cse")] /// Run common-subexpression-elimination. This elides duplicate expressions and caches their /// outputs. const COMM_SUBEXPR_ELIM = 1 << 9; diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index b75b4d67b55b..d067b990765b 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -20,9 +20,10 @@ fn expand_expressions( exprs: Vec, lp_arena: &Arena, expr_arena: &mut Arena, + opt_flags: &mut OptFlags, ) -> PolarsResult> { let schema = lp_arena.get(input).schema(lp_arena); - let exprs = rewrite_projections(exprs, &schema, &[])?; + let exprs = rewrite_projections(exprs, &schema, &[], opt_flags)?; to_expr_irs(exprs, expr_arena) } @@ -57,17 +58,18 @@ pub fn to_alp( expr_arena: &mut Arena, lp_arena: &mut Arena, // Only `SIMPLIFY_EXPR` and `TYPE_COERCION` are respected. - opt_state: &mut OptFlags, + opt_flags: &mut OptFlags, ) -> PolarsResult { let conversion_optimizer = ConversionOptimizer::new( - opt_state.contains(OptFlags::SIMPLIFY_EXPR), - opt_state.contains(OptFlags::TYPE_COERCION), + opt_flags.contains(OptFlags::SIMPLIFY_EXPR), + opt_flags.contains(OptFlags::TYPE_COERCION), ); let mut ctxt = ConversionContext { expr_arena, lp_arena, conversion_optimizer, + opt_flags, }; to_alp_impl(lp, &mut ctxt) @@ -77,6 +79,7 @@ struct ConversionContext<'a> { expr_arena: &'a mut Arena, lp_arena: &'a mut Arena, conversion_optimizer: ConversionOptimizer, + opt_flags: &'a mut OptFlags, } /// converts LogicalPlan to IR @@ -305,7 +308,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult { let mut input = to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(filter)))?; - let predicate = expand_filter(predicate, input, ctxt.lp_arena) + let predicate = expand_filter(predicate, input, ctxt.lp_arena, ctxt.opt_flags) .map_err(|e| e.context(failed_here!(filter)))?; let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?; @@ -378,8 +381,8 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult PolarsResult PolarsResult PolarsResult PolarsResult>(); - let (exprs, schema) = - resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena) - .map_err(|e| e.context(failed_here!(fill_nan)))?; + let (exprs, schema) = resolve_with_columns( + exprs, + input, + ctxt.lp_arena, + ctxt.expr_arena, + ctxt.opt_flags, + ) + .map_err(|e| e.context(failed_here!(fill_nan)))?; ctxt.conversion_optimizer .fill_scratch(&exprs, ctxt.expr_arena); @@ -911,7 +932,12 @@ fn expand_scan_paths_with_hive_update( Ok(expanded_paths) } -fn expand_filter(predicate: Expr, input: Node, lp_arena: &Arena) -> PolarsResult { +fn expand_filter( + predicate: Expr, + input: Node, + lp_arena: &Arena, + opt_flags: &mut OptFlags, +) -> PolarsResult { let schema = lp_arena.get(input).schema(lp_arena); let predicate = if has_expr(&predicate, |e| match e { Expr::Column(name) => is_regex_projection(name), @@ -924,7 +950,7 @@ fn expand_filter(predicate: Expr, input: Node, lp_arena: &Arena) -> PolarsRe | Expr::Nth(_) => true, _ => false, }) { - let mut rewritten = rewrite_projections(vec![predicate], &schema, &[])?; + let mut rewritten = rewrite_projections(vec![predicate], &schema, &[], opt_flags)?; match rewritten.len() { 1 => { // all good @@ -971,10 +997,11 @@ fn resolve_with_columns( input: Node, lp_arena: &Arena, expr_arena: &mut Arena, + opt_flags: &mut OptFlags, ) -> PolarsResult<(Vec, SchemaRef)> { let schema = lp_arena.get(input).schema(lp_arena); let mut new_schema = (**schema).clone(); - let (exprs, _) = prepare_projection(exprs, &schema)?; + let (exprs, _) = prepare_projection(exprs, &schema, opt_flags)?; let mut output_names = PlHashSet::with_capacity(exprs.len()); let mut arena = Arena::with_capacity(8); @@ -1008,10 +1035,11 @@ fn resolve_group_by( _options: &GroupbyOptions, lp_arena: &Arena, expr_arena: &mut Arena, + opt_flags: &mut OptFlags, ) -> PolarsResult<(Vec, Vec, SchemaRef)> { let current_schema = lp_arena.get(input).schema(lp_arena); let current_schema = current_schema.as_ref(); - let mut keys = rewrite_projections(keys, current_schema, &[])?; + let mut keys = rewrite_projections(keys, current_schema, &[], opt_flags)?; // Initialize schema from keys let mut schema = expressions_to_schema(&keys, current_schema, Context::Default)?; @@ -1042,7 +1070,7 @@ fn resolve_group_by( } let keys_index_len = schema.len(); - let aggs = rewrite_projections(aggs, current_schema, &keys)?; + let aggs = rewrite_projections(aggs, current_schema, &keys, opt_flags)?; if pop_keys { let _ = keys.pop(); } diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index bcb6f957a7b2..d72fb0e00ed2 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -6,8 +6,9 @@ use super::*; pub(crate) fn prepare_projection( exprs: Vec, schema: &Schema, + opt_flags: &mut OptFlags, ) -> PolarsResult<(Vec, Schema)> { - let exprs = rewrite_projections(exprs, schema, &[])?; + let exprs = rewrite_projections(exprs, schema, &[], opt_flags)?; let schema = expressions_to_schema(&exprs, schema, Context::Default)?; Ok((exprs, schema)) } @@ -541,14 +542,18 @@ fn prepare_excluded( } // functions can have col(["a", "b"]) or col(String) as inputs -fn expand_function_inputs(expr: Expr, schema: &Schema) -> PolarsResult { +fn expand_function_inputs( + expr: Expr, + schema: &Schema, + opt_flags: &mut OptFlags, +) -> PolarsResult { expr.try_map_expr(|mut e| match &mut e { Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. } if options .flags .contains(FunctionFlags::INPUT_WILDCARD_EXPANSION) => { - *input = rewrite_projections(core::mem::take(input), schema, &[]).unwrap(); + *input = rewrite_projections(core::mem::take(input), schema, &[], opt_flags).unwrap(); if input.is_empty() && !options.flags.contains(FunctionFlags::ALLOW_EMPTY_INPUTS) { // Needed to visualize the error *input = vec![Expr::Literal(LiteralValue::Null)]; @@ -639,12 +644,27 @@ fn find_flags(expr: &Expr) -> PolarsResult { }) } +#[cfg(feature = "dtype-struct")] +fn toggle_cse(opt_flags: &mut OptFlags) { + if opt_flags.contains(OptFlags::EAGER) { + #[cfg(debug_assertions)] + { + use polars_core::config::verbose; + if verbose() { + eprintln!("CSE turned on because of struct expansion") + } + } + *opt_flags |= OptFlags::COMM_SUBEXPR_ELIM; + } +} + /// In case of single col(*) -> do nothing, no selection is the same as select all /// In other cases replace the wildcard with an expression with all columns pub(crate) fn rewrite_projections( exprs: Vec, schema: &Schema, keys: &[Expr], + opt_flags: &mut OptFlags, ) -> PolarsResult> { let mut result = Vec::with_capacity(exprs.len() + schema.len()); @@ -653,7 +673,7 @@ pub(crate) fn rewrite_projections( let result_offset = result.len(); // Functions can have col(["a", "b"]) or col(String) as inputs. - expr = expand_function_inputs(expr, schema)?; + expr = expand_function_inputs(expr, schema, opt_flags)?; let mut flags = find_flags(&expr)?; if flags.has_selector { @@ -662,10 +682,11 @@ pub(crate) fn rewrite_projections( flags.multiple_columns = true; } - replace_and_add_to_results(expr, flags, &mut result, schema, keys)?; + replace_and_add_to_results(expr, flags, &mut result, schema, keys, opt_flags)?; #[cfg(feature = "dtype-struct")] if flags.has_struct_field_by_index { + toggle_cse(opt_flags); for e in &mut result[result_offset..] { *e = struct_index_to_field(std::mem::take(e), schema)?; } @@ -680,6 +701,7 @@ fn replace_and_add_to_results( result: &mut Vec, schema: &Schema, keys: &[Expr], + opt_flags: &mut OptFlags, ) -> PolarsResult<()> { if flags.has_nth { expr = replace_nth(expr, schema); @@ -732,6 +754,7 @@ fn replace_and_add_to_results( &mut intermediate, schema, keys, + opt_flags, )?; // Then expand the fields and add to the final result vec. @@ -739,12 +762,13 @@ fn replace_and_add_to_results( flags.multiple_columns = false; flags.has_wildcard = false; for e in intermediate { - replace_and_add_to_results(e, flags, result, schema, keys)?; + replace_and_add_to_results(e, flags, result, schema, keys, opt_flags)?; } } // has only field expansion // col('a').struct.field('*') else { + toggle_cse(opt_flags); expand_struct_fields(e, &expr, result, schema, names, &exclude)? } }, @@ -787,7 +811,14 @@ fn replace_selector_inner( match s { Selector::Root(expr) => { let local_flags = find_flags(&expr)?; - replace_and_add_to_results(*expr, local_flags, scratch, schema, keys)?; + replace_and_add_to_results( + *expr, + local_flags, + scratch, + schema, + keys, + &mut Default::default(), + )?; members.extend(scratch.drain(..)) }, Selector::Add(lhs, rhs) => { diff --git a/crates/polars-plan/src/plans/optimizer/mod.rs b/crates/polars-plan/src/plans/optimizer/mod.rs index 49dacbf7e6b8..4215347f2e7d 100644 --- a/crates/polars-plan/src/plans/optimizer/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/mod.rs @@ -72,6 +72,13 @@ pub fn optimize( let opt = StackOptimizer {}; let mut rules: Vec> = Vec::with_capacity(8); + // Unset CSE + // This can be turned on again during ir-conversion. + #[allow(clippy::eq_op)] + #[cfg(feature = "cse")] + if opt_state.contains(OptFlags::EAGER) { + opt_state &= !(OptFlags::COMM_SUBEXPR_ELIM | OptFlags::COMM_SUBEXPR_ELIM); + } let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena, &mut opt_state)?; // get toggle values @@ -87,10 +94,10 @@ pub fn optimize( // This keeps eager execution more snappy. let eager = opt_state.contains(OptFlags::EAGER); #[cfg(feature = "cse")] - let comm_subplan_elim = opt_state.contains(OptFlags::COMM_SUBPLAN_ELIM) && !eager; + let comm_subplan_elim = opt_state.contains(OptFlags::COMM_SUBPLAN_ELIM); #[cfg(feature = "cse")] - let comm_subexpr_elim = opt_state.contains(OptFlags::COMM_SUBEXPR_ELIM) && !eager; + let comm_subexpr_elim = opt_state.contains(OptFlags::COMM_SUBEXPR_ELIM); #[cfg(not(feature = "cse"))] let comm_subexpr_elim = false; diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index a330cda1b8a9..5a519dc94e2c 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -783,3 +783,15 @@ def test_cse_chunks_18124() -> None: ) .filter(pl.col("ts_diff") > 1) ).collect().shape == (4, 4) + + +def test_eager_cse_during_struct_expansion_18411() -> None: + df = pl.DataFrame({"foo": [0, 0, 0, 1, 1]}) + vc = pl.col("foo").value_counts() + classes = vc.struct[0] + counts = vc.struct[1] + # Check if output is stable + assert ( + df.select(pl.col("foo").replace(classes, counts)) + == df.select(pl.col("foo").replace(classes, counts)) + )["foo"].all()