From 55ec6b69380dccec39b002a9f68248069b5b1d87 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Fri, 29 Nov 2024 03:22:10 +1100 Subject: [PATCH] c --- .../physical_plan/streaming/convert_alp.rs | 18 +- .../src/executors/projection.rs | 4 +- .../polars-mem-engine/src/executors/stack.rs | 4 +- crates/polars-mem-engine/src/planner/lp.rs | 23 +- crates/polars-plan/src/plans/aexpr/mod.rs | 57 +++-- .../polars-plan/src/plans/aexpr/traverse.rs | 2 +- crates/polars-plan/src/plans/aexpr/utils.rs | 226 ++++++++++-------- .../src/plans/conversion/dsl_to_ir.rs | 6 +- .../polars-plan/src/plans/conversion/join.rs | 11 +- crates/polars-plan/src/plans/lit.rs | 12 +- .../src/plans/optimizer/cache_states.rs | 2 +- .../src/plans/optimizer/cse/cse_expr.rs | 4 +- crates/polars-plan/src/plans/optimizer/mod.rs | 4 +- .../optimizer/predicate_pushdown/group_by.rs | 2 +- .../optimizer/predicate_pushdown/join.rs | 2 +- .../plans/optimizer/predicate_pushdown/mod.rs | 115 +++++---- .../optimizer/predicate_pushdown/utils.rs | 119 ++------- .../plans/optimizer/slice_pushdown_expr.rs | 2 +- .../src/plans/optimizer/slice_pushdown_lp.rs | 73 +++--- crates/polars-plan/src/plans/options.rs | 28 +-- crates/polars-plan/src/utils.rs | 1 + .../src/physical_plan/lower_expr.rs | 87 ++----- .../src/physical_plan/lower_ir.rs | 10 +- crates/polars-stream/src/skeleton.rs | 6 +- .../tests/unit/streaming/test_streaming.py | 19 ++ py-polars/tests/unit/test_predicates.py | 9 +- 26 files changed, 404 insertions(+), 442 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index 6c84af4510b5..c9205b22d406 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -163,18 +163,16 @@ pub(crate) fn insert_streaming_nodes( execution_id += 1; match lp_arena.get(root) { Filter { input, predicate } - if is_streamable( - predicate.node(), - expr_arena, - IsStreamableContext::new(Default::default()), - ) => + if is_elementwise_rec(expr_arena.get(predicate.node()), expr_arena) => { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) }, HStack { input, exprs, .. } - if all_streamable(exprs, expr_arena, Default::default()) => + if exprs + .iter() + .all(|e| is_elementwise_rec(expr_arena.get(e.node()), expr_arena)) => { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); @@ -201,11 +199,9 @@ pub(crate) fn insert_streaming_nodes( stack.push(StackFrame::new(*input, state, current_idx)) }, Select { input, expr, .. 
} - if all_streamable( - expr, - expr_arena, - IsStreamableContext::new(Default::default()), - ) => + if expr + .iter() + .all(|e| is_elementwise_rec(expr_arena.get(e.node()), expr_arena)) => { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); diff --git a/crates/polars-mem-engine/src/executors/projection.rs b/crates/polars-mem-engine/src/executors/projection.rs index 3bb77f4d431d..bbab42740c5e 100644 --- a/crates/polars-mem-engine/src/executors/projection.rs +++ b/crates/polars-mem-engine/src/executors/projection.rs @@ -13,7 +13,7 @@ pub struct ProjectionExec { pub(crate) schema: SchemaRef, pub(crate) options: ProjectionOptions, // Can run all operations elementwise - pub(crate) streamable: bool, + pub(crate) allow_vertical_parallelism: bool, } impl ProjectionExec { @@ -23,7 +23,7 @@ impl ProjectionExec { mut df: DataFrame, ) -> PolarsResult { // Vertical and horizontal parallelism. - let df = if self.streamable + let df = if self.allow_vertical_parallelism && df.first_col_n_chunks() > 1 && df.height() > POOL.current_num_threads() * 2 && self.options.run_parallel diff --git a/crates/polars-mem-engine/src/executors/stack.rs b/crates/polars-mem-engine/src/executors/stack.rs index 0b2dbfd01da3..e48d7438e23c 100644 --- a/crates/polars-mem-engine/src/executors/stack.rs +++ b/crates/polars-mem-engine/src/executors/stack.rs @@ -11,7 +11,7 @@ pub struct StackExec { pub(crate) output_schema: SchemaRef, pub(crate) options: ProjectionOptions, // Can run all operations elementwise - pub(crate) streamable: bool, + pub(crate) allow_vertical_parallelism: bool, } impl StackExec { @@ -23,7 +23,7 @@ impl StackExec { let schema = &*self.output_schema; // Vertical and horizontal parallelism. - let df = if self.streamable + let df = if self.allow_vertical_parallelism && df.first_col_n_chunks() > 1 && df.height() > 0 && self.options.run_parallel diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 0d438b5f5bd1..ad58bb72e623 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -239,11 +239,8 @@ fn create_physical_plan_impl( Ok(Box::new(executors::SliceExec { input, offset, len })) }, Filter { input, predicate } => { - let mut streamable = is_streamable( - predicate.node(), - expr_arena, - IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false), - ); + let mut streamable = + is_elementwise_rec_no_cat_cast(expr_arena.get(predicate.node()), expr_arena); let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); if streamable { // This can cause problems with string caches @@ -386,7 +383,7 @@ fn create_physical_plan_impl( &mut state, )?; - let streamable = options.should_broadcast && all_streamable(&expr, expr_arena, IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false)) + let allow_vertical_parallelism = options.should_broadcast && expr.iter().all(|e| is_elementwise_rec_no_cat_cast(expr_arena.get(e.node()), expr_arena)) // If all columns are literal we would get a 1 row per thread. 
&& !phys_expr.iter().all(|p| { p.is_literal() @@ -400,7 +397,7 @@ fn create_physical_plan_impl( #[cfg(test)] schema: _schema, options, - streamable, + allow_vertical_parallelism, })) }, Reduce { @@ -635,12 +632,10 @@ fn create_physical_plan_impl( let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?; - let streamable = options.should_broadcast - && all_streamable( - &exprs, - expr_arena, - IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false), - ); + let allow_vertical_parallelism = options.should_broadcast + && exprs + .iter() + .all(|e| is_elementwise_rec_no_cat_cast(expr_arena.get(e.node()), expr_arena)); let mut state = ExpressionConversionState::new( POOL.current_num_threads() > exprs.len(), @@ -661,7 +656,7 @@ fn create_physical_plan_impl( input_schema, output_schema, options, - streamable, + allow_vertical_parallelism, })) }, MapFunction { diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index e53dc50dc6d9..b0217e4c02ae 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -18,6 +18,7 @@ pub use scalar::is_scalar_ae; use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; pub use traverse::*; +pub(crate) use utils::permits_filter_pushdown; pub use utils::*; use crate::constants::LEN; @@ -218,35 +219,41 @@ impl AExpr { pub(crate) fn col(name: PlSmallStr) -> Self { AExpr::Column(name) } - /// Any expression that is sensitive to the number of elements in a group - /// - Aggregations - /// - Sorts - /// - Counts - /// - .. - pub(crate) fn groups_sensitive(&self) -> bool { + + /// Checks whether this expression is elementwise. This only checks the top level expression. + pub(crate) fn is_elementwise_top_level(&self) -> bool { use AExpr::*; + match self { - Function { options, .. } | AnonymousFunction { options, .. } => { - options.is_groups_sensitive() - } - Sort { .. } - | SortBy { .. } - | Agg { .. } - | Window { .. } + AnonymousFunction { options, .. } => options.is_elementwise(), + + // Non-strict strptime must be done in-memory to ensure the format + // is consistent across the entire dataframe. + #[cfg(feature = "strings")] + Function { + options, + function: FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)), + .. + } => { + assert!(options.is_elementwise()); + opts.strict + }, + + Function { options, .. } => options.is_elementwise(), + + Literal(v) => v.projects_as_scalar(), + + Alias(_, _) | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true, + + Agg { .. } + | Explode(_) + | Filter { .. } + | Gather { .. } | Len | Slice { .. } - | Gather { .. } - => true, - Alias(_, _) - | Explode(_) - | Column(_) - | Literal(_) - // a caller should traverse binary and ternary - // to determine if the whole expr. is group sensitive - | BinaryExpr { .. } - | Ternary { .. } - | Cast { .. } - | Filter { .. } => false, + | Sort { .. } + | SortBy { .. } + | Window { .. } => false, } } diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index 1697a5571d4e..20e7a454169c 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -2,7 +2,7 @@ use super::*; impl AExpr { /// Push nodes at this level to a pre-allocated stack. 
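    // Review note: `nodes` is the helper that `is_elementwise` (plans/aexpr/utils.rs) uses to
    // spill an expression's child nodes onto the caller-provided scratch stack; the signature
    // change below takes that stack as `impl PushNode` instead of a named generic parameter.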
- pub(crate) fn nodes(&self, container: &mut C) { + pub(crate) fn nodes(&self, container: &mut impl PushNode) { use AExpr::*; match self { diff --git a/crates/polars-plan/src/plans/aexpr/utils.rs b/crates/polars-plan/src/plans/aexpr/utils.rs index 6520cc476178..90eaa474ceca 100644 --- a/crates/polars-plan/src/plans/aexpr/utils.rs +++ b/crates/polars-plan/src/plans/aexpr/utils.rs @@ -1,121 +1,147 @@ -use bitflags::bitflags; - use super::*; -fn has_series_or_range(ae: &AExpr) -> bool { - matches!( - ae, - AExpr::Literal(LiteralValue::Series(_) | LiteralValue::Range { .. }) - ) -} +/// Checks if the top-level expression node is elementwise. If this is the case, then `stack` will +/// be extended further with any nested expression nodes. +pub fn is_elementwise(stack: &mut Vec, ae: &AExpr, expr_arena: &Arena) -> bool { + use AExpr::*; -bitflags! { - #[derive(Default, Copy, Clone)] - struct StreamableFlags: u8 { - const ALLOW_CAST_CATEGORICAL = 1; - } -} + if !ae.is_elementwise_top_level() { + return false; + } -#[derive(Copy, Clone)] -pub struct IsStreamableContext { - flags: StreamableFlags, - context: Context, -} + match ae { + // Literals that aren't being projected are allowed to be non-scalar, so we don't add them + // for inspection. (e.g. `is_in()`). + #[cfg(feature = "is_in")] + Function { + function: FunctionExpr::Boolean(BooleanFunction::IsIn), + input, + .. + } => (|| { + if let Some(rhs) = input.get(1) { + assert_eq!(input.len(), 2); // A.is_in(B) + let rhs = rhs.node(); -impl Default for IsStreamableContext { - fn default() -> Self { - Self { - flags: StreamableFlags::all(), - context: Default::default(), - } + if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) { + stack.push_node(input[0].node()); + return; + } + }; + + ae.nodes(stack); + })(), + _ => ae.nodes(stack), } + + true } -impl IsStreamableContext { - pub fn new(ctx: Context) -> Self { - Self { - flags: StreamableFlags::all(), - context: ctx, +/// Recursive variant of `is_elementwise` +pub fn is_elementwise_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena) -> bool { + let mut stack = vec![]; + + loop { + if !is_elementwise(&mut stack, ae, expr_arena) { + return false; } - } - pub fn with_allow_cast_categorical(mut self, allow_cast_categorical: bool) -> Self { - self.flags.set( - StreamableFlags::ALLOW_CAST_CATEGORICAL, - allow_cast_categorical, - ); - self + let Some(node) = stack.pop() else { + break; + }; + + ae = expr_arena.get(node); } + + true } -pub fn is_streamable(node: Node, expr_arena: &Arena, ctx: IsStreamableContext) -> bool { - // check whether leaf column is Col or Lit - let mut seen_column = false; - let mut seen_lit_range = false; - let all = expr_arena.iter(node).all(|(_, ae)| match ae { - AExpr::Function { - function: FunctionExpr::SetSortedFlag(_), - .. - } => true, - AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => { - match ctx.context { - Context::Default => matches!( - options.collect_groups, - ApplyOptions::ElementWise | ApplyOptions::ApplyList - ), - Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise), - } - }, - AExpr::Column(_) => { - seen_column = true; - true - }, - AExpr::BinaryExpr { left, right, .. 
} => { - !has_aexpr(*left, expr_arena, has_series_or_range) - && !has_aexpr(*right, expr_arena, has_series_or_range) - }, - AExpr::Ternary { - truthy, - falsy, - predicate, - } => { - !has_aexpr(*truthy, expr_arena, has_series_or_range) - && !has_aexpr(*falsy, expr_arena, has_series_or_range) - && !has_aexpr(*predicate, expr_arena, has_series_or_range) - }, +/// Recursive variant of `is_elementwise` that also forbids casting to categoricals. This function +/// is used to determine if an expression evaluation can be vertically parallelized. +pub fn is_elementwise_rec_no_cat_cast<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena) -> bool { + let mut stack = vec![]; + + loop { + if !is_elementwise(&mut stack, ae, expr_arena) { + return false; + } + #[cfg(feature = "dtype-categorical")] - AExpr::Cast { dtype, .. } if matches!(dtype, DataType::Categorical(_, _)) => { - ctx.flags.contains(StreamableFlags::ALLOW_CAST_CATEGORICAL) - }, - AExpr::Alias(_, _) | AExpr::Cast { .. } => true, - AExpr::Literal(lv) => match lv { - LiteralValue::Series(_) | LiteralValue::Range { .. } => { - seen_lit_range = true; - true - }, - _ => true, - }, - _ => false, - }); - - if all { - // adding a range or literal series to chunks will fail because sizes don't match - // if column is a leaf column then it is ok - // - so we want to block `with_column(lit(Series))` - // - but we want to allow `with_column(col("foo").is_in(Series))` - // that means that IFF we seen a lit_range, we only allow if we also seen a `column`. - return if seen_lit_range { seen_column } else { true }; + { + if let AExpr::Cast { + dtype: DataType::Categorical(..), + .. + } = ae + { + return false; + } + } + + let Some(node) = stack.pop() else { + break; + }; + + ae = expr_arena.get(node); } - false + true } -pub fn all_streamable( - exprs: &[ExprIR], +/// Check whether filters can be pushed past this expression. +/// +/// A query, `with_columns(C).filter(P)` can be re-ordered as `filter(P).with_columns(C)`, iff +/// both P and C permit filter pushdown. +/// +/// If filter pushdown is permitted, `stack` is extended with any input expression nodes that this +/// expression may have. +/// +/// Note that this function is not recursive - the caller should repeatedly +/// call this function with the `stack` to perform a recursive check. +pub(crate) fn permits_filter_pushdown( + stack: &mut Vec, + ae: &AExpr, expr_arena: &Arena, - ctx: IsStreamableContext, ) -> bool { - exprs - .iter() - .all(|e| is_streamable(e.node(), expr_arena, ctx)) + // This is a subset of an `is_elementwise` check that also blocks exprs that raise errors + // depending on the data. The idea is that, although the success value of these functions + // are elementwise, their error behavior is non-elementwise. Their error behavior is essentially + // performing an aggregation `ANY(evaluation_result_was_error)`, and if this is the case then + // the query result should be an error. + match ae { + // Rows that go OOB on get/gather may be filtered out in earlier operations, + // so we don't push these down. + AExpr::Function { + function: FunctionExpr::ListExpr(ListFunction::Get(false)), + .. + } => false, + #[cfg(feature = "list_gather")] + AExpr::Function { + function: FunctionExpr::ListExpr(ListFunction::Gather(false)), + .. + } => false, + #[cfg(feature = "dtype-array")] + AExpr::Function { + function: FunctionExpr::ArrayExpr(ArrayFunction::Get(false)), + .. + } => false, + // TODO: There are a lot more functions that should be caught here. 
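        // Review note (illustrative, column names taken from the test frame below): in
        //   with_columns(col("val").list.get(1)).filter(col("len") >= 2)
        // pushing the filter below the projection would drop the short row before
        // `list.get(1)` runs, silently swallowing the out-of-bounds error the un-reordered
        // plan raises; that is the non-elementwise error behaviour described above.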
+ ae => is_elementwise(stack, ae, expr_arena), + } +} + +pub fn permits_filter_pushdown_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena) -> bool { + let mut stack = vec![]; + + loop { + if !permits_filter_pushdown(&mut stack, ae, expr_arena) { + return false; + } + + let Some(node) = stack.pop() else { + break; + }; + + ae = expr_arena.get(node); + } + + true } diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index b323b3e3e705..34f03e6debdd 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -406,7 +406,11 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?; - return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Default::default()) { + // TODO: We could do better here by using `pushdown_eligibility()` + return if permits_filter_pushdown_rec( + ctxt.expr_arena.get(predicate_ae.node()), + ctxt.expr_arena, + ) { // Split expression that are ANDed into multiple Filter nodes as the optimizer can then // push them down independently. Especially if they refer columns from different tables // this will be more performant. diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 4d38bb6b158f..8cefab436078 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -134,12 +134,17 @@ pub fn resolve_join( } // Every expression must be elementwise so that we are // guaranteed the keys for a join are all the same length. - let all_elementwise = - |aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Default::default()); + let all_elementwise = |aexprs: &[ExprIR]| { + aexprs + .iter() + .all(|e| is_elementwise_rec(ctxt.expr_arena.get(e.node()), ctxt.expr_arena)) + }; + polars_ensure!( all_elementwise(&left_on) && all_elementwise(&right_on), - InvalidOperation: "All join key expressions must be elementwise." + InvalidOperation: "all join key expressions must be elementwise." ); + let lp = IR::Join { input_left, input_right, diff --git a/crates/polars-plan/src/plans/lit.rs b/crates/polars-plan/src/plans/lit.rs index 74feffd60da0..56fa06b457e8 100644 --- a/crates/polars-plan/src/plans/lit.rs +++ b/crates/polars-plan/src/plans/lit.rs @@ -98,10 +98,16 @@ impl LiteralValue { } } + pub fn is_scalar(&self) -> bool { + !matches!(self, LiteralValue::Series(_) | LiteralValue::Range { .. }) + } + + /// Less-strict `is_scalar` check - generally used for internal functionality such as our + /// optimizers. pub(crate) fn projects_as_scalar(&self) -> bool { match self { - LiteralValue::Range { low, high, .. } => high.saturating_sub(*low) == 1, LiteralValue::Series(s) => s.len() == 1, + LiteralValue::Range { low, high, .. } => high.saturating_sub(*low) == 1, _ => true, } } @@ -230,10 +236,6 @@ impl LiteralValue { LiteralValue::UInt32(value) } } - - pub fn is_scalar(&self) -> bool { - !matches!(self, LiteralValue::Series(_) | LiteralValue::Range { .. 
}) - } } pub trait Literal { diff --git a/crates/polars-plan/src/plans/optimizer/cache_states.rs b/crates/polars-plan/src/plans/optimizer/cache_states.rs index f6968cc6f7d2..0162cdfc0645 100644 --- a/crates/polars-plan/src/plans/optimizer/cache_states.rs +++ b/crates/polars-plan/src/plans/optimizer/cache_states.rs @@ -290,7 +290,7 @@ pub(super) fn set_cache_states( // back to the cache node again if !cache_schema_and_children.is_empty() { let mut proj_pd = ProjectionPushDown::new(); - let pred_pd = PredicatePushDown::new(expr_eval).block_at_cache(false); + let mut pred_pd = PredicatePushDown::new(expr_eval).block_at_cache(false); for (_cache_id, v) in cache_schema_and_children { // # CHECK IF WE NEED TO REMOVE CACHES // If we encounter multiple predicates we remove the cache nodes completely as we don't diff --git a/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs index 700a82720eb4..ee0b9472ac01 100644 --- a/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs +++ b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs @@ -350,11 +350,11 @@ impl ExprIdentifierVisitor<'_> { // other operations we cannot add to the state as they have the output size of the // groups, not the original dataframe if self.is_group_by { - if ae.groups_sensitive() { + if !ae.is_elementwise_top_level() { return REFUSE_NO_MEMBER; } match ae { - AExpr::AnonymousFunction { .. } | AExpr::Filter { .. } => REFUSE_NO_MEMBER, + AExpr::AnonymousFunction { .. } => REFUSE_NO_MEMBER, AExpr::Cast { .. } => REFUSE_ALLOW_MEMBER, _ => ACCEPT, } diff --git a/crates/polars-plan/src/plans/optimizer/mod.rs b/crates/polars-plan/src/plans/optimizer/mod.rs index dc0d330d8b86..4b182b835e53 100644 --- a/crates/polars-plan/src/plans/optimizer/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/mod.rs @@ -157,7 +157,7 @@ pub fn optimize( } if predicate_pushdown { - let predicate_pushdown_opt = PredicatePushDown::new(expr_eval); + let mut predicate_pushdown_opt = PredicatePushDown::new(expr_eval); let alp = lp_arena.take(lp_top); let alp = predicate_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; lp_arena.replace(lp_top, alp); @@ -182,7 +182,7 @@ pub fn optimize( } if slice_pushdown { - let slice_pushdown_opt = SlicePushDown::new(streaming, new_streaming); + let mut slice_pushdown_opt = SlicePushDown::new(streaming, new_streaming); let alp = lp_arena.take(lp_top); let alp = slice_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs index 6c6d4460b29e..fd7eba7b9118 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs @@ -2,7 +2,7 @@ use super::*; #[allow(clippy::too_many_arguments)] pub(super) fn process_group_by( - opt: &PredicatePushDown, + opt: &mut PredicatePushDown, lp_arena: &mut Arena, expr_arena: &mut Arena, input: Node, diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs index 3437bbc8b3e1..6449e63ad4ed 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs @@ -121,7 +121,7 @@ fn predicate_applies_to_both_tables( #[allow(clippy::too_many_arguments)] pub(super) fn process_join( - opt: &PredicatePushDown, + opt: &mut 
PredicatePushDown, lp_arena: &mut Arena, expr_arena: &mut Arena, input_left: Node, diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs index f9644b1aef0c..6c99af306c34 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -19,28 +19,47 @@ use crate::utils::{check_input_node, has_aexpr}; pub type ExprEval<'a> = Option<&'a dyn Fn(&ExprIR, &Arena, &SchemaRef) -> Option>>; -pub struct PredicatePushDown<'a> { - expr_eval: ExprEval<'a>, - verbose: bool, - block_at_cache: bool, -} +/// The struct is wrapped in a mod to prevent direct member access of `nodes_scratch` +mod inner { + use polars_core::config::verbose; + use polars_utils::arena::Node; + + use super::ExprEval; + + pub struct PredicatePushDown<'a> { + pub(super) expr_eval: ExprEval<'a>, + pub(super) verbose: bool, + pub(super) block_at_cache: bool, + nodes_scratch: Vec, + } + + impl<'a> PredicatePushDown<'a> { + pub fn new(expr_eval: ExprEval<'a>) -> Self { + Self { + expr_eval, + verbose: verbose(), + block_at_cache: true, + nodes_scratch: vec![], + } + } -impl<'a> PredicatePushDown<'a> { - pub fn new(expr_eval: ExprEval<'a>) -> Self { - Self { - expr_eval, - verbose: verbose(), - block_at_cache: true, + pub(super) fn nodes_scratch_mut(&mut self) -> &mut Vec { + self.nodes_scratch.clear(); + &mut self.nodes_scratch } } +} +pub use inner::PredicatePushDown; + +impl PredicatePushDown<'_> { pub(crate) fn block_at_cache(mut self, toggle: bool) -> Self { self.block_at_cache = toggle; self } fn optional_apply_predicate( - &self, + &mut self, lp: IR, local_predicates: Vec, lp_arena: &mut Arena, @@ -57,7 +76,7 @@ impl<'a> PredicatePushDown<'a> { } fn pushdown_and_assign( - &self, + &mut self, input: Node, acc_predicates: PlHashMap, lp_arena: &mut Arena, @@ -71,7 +90,7 @@ impl<'a> PredicatePushDown<'a> { /// Filter will be pushed down. fn pushdown_and_continue( - &self, + &mut self, lp: IR, mut acc_predicates: PlHashMap, lp_arena: &mut Arena, @@ -89,8 +108,13 @@ impl<'a> PredicatePushDown<'a> { } let input = inputs[inputs.len() - 1]; - let (eligibility, alias_rename_map) = - pushdown_eligibility(&exprs, &[], &acc_predicates, expr_arena)?; + let (eligibility, alias_rename_map) = pushdown_eligibility( + &exprs, + &[], + &acc_predicates, + expr_arena, + self.nodes_scratch_mut(), + )?; let local_predicates = match eligibility { PushdownEligibility::Full => vec![], @@ -186,7 +210,7 @@ impl<'a> PredicatePushDown<'a> { /// Filter will be done at this node, but we continue optimization fn no_pushdown_restart_opt( - &self, + &mut self, lp: IR, acc_predicates: PlHashMap, lp_arena: &mut Arena, @@ -217,7 +241,7 @@ impl<'a> PredicatePushDown<'a> { } fn no_pushdown( - &self, + &mut self, lp: IR, acc_predicates: PlHashMap, lp_arena: &mut Arena, @@ -241,7 +265,7 @@ impl<'a> PredicatePushDown<'a> { /// * `expr_arena` - The local memory arena for the expressions. #[recursive] fn push_down( - &self, + &mut self, lp: IR, mut acc_predicates: PlHashMap, lp_arena: &mut Arena, @@ -267,9 +291,10 @@ impl<'a> PredicatePushDown<'a> { let local_predicates = match pushdown_eligibility( &[], - &[(tmp_key.clone(), predicate.clone())], + &[predicate.clone()], &acc_predicates, expr_arena, + self.nodes_scratch_mut(), )? .0 { @@ -632,34 +657,32 @@ impl<'a> PredicatePushDown<'a> { acc_predicates, ), lp @ Union { .. 
} => { - let mut local_predicates = vec![]; - - // a count is influenced by a Union/Vstack - acc_predicates.retain(|_, predicate| { - if has_aexpr(predicate.node(), expr_arena, |ae| matches!(ae, AExpr::Len)) { - local_predicates.push(predicate.clone()); - false - } else { - true + if cfg!(debug_assertions) { + for v in acc_predicates.values() { + let ae = expr_arena.get(v.node()); + assert!(permits_filter_pushdown( + self.nodes_scratch_mut(), + ae, + expr_arena + )); } - }); - let lp = - self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; - Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + } + + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) }, lp @ Sort { .. } => { - let mut local_predicates = vec![]; - acc_predicates.retain(|_, predicate| { - if predicate_is_sort_boundary(predicate.node(), expr_arena) { - local_predicates.push(predicate.clone()); - false - } else { - true + if cfg!(debug_assertions) { + for v in acc_predicates.values() { + let ae = expr_arena.get(v.node()); + assert!(permits_filter_pushdown( + self.nodes_scratch_mut(), + ae, + expr_arena + )); } - }); - let lp = - self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; - Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + } + + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, true) }, // Pushed down passed these nodes lp @ Sink { .. } => { @@ -694,7 +717,7 @@ impl<'a> PredicatePushDown<'a> { if let Some(predicate) = predicate { // For IO plugins we only accept streamable expressions as // we want to apply the predicates to the batches. - if !is_streamable(predicate.node(), expr_arena, Default::default()) + if !is_elementwise_rec(expr_arena.get(predicate.node()), expr_arena) && matches!(options.python_source, PythonScanSource::IOPlugin) { let lp = PythonScan { options }; @@ -715,7 +738,7 @@ impl<'a> PredicatePushDown<'a> { } pub(crate) fn optimize( - &self, + &mut self, logical_plan: IR, lp_arena: &mut Arena, expr_arena: &mut Arena, diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs index 7f14f2269cfd..59f9346ad8c7 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs @@ -79,34 +79,6 @@ pub(super) fn predicate_at_scan( } } -fn shifts_elements(node: Node, expr_arena: &Arena) -> bool { - let matches = |e: &AExpr| { - matches!( - e, - AExpr::Function { - function: FunctionExpr::Shift | FunctionExpr::ShiftAndFill, - .. - } - ) - }; - has_aexpr(node, expr_arena, matches) -} - -pub(super) fn predicate_is_sort_boundary(node: Node, expr_arena: &Arena) -> bool { - let matches = |e: &AExpr| match e { - AExpr::Window { function, .. } => shifts_elements(*function, expr_arena), - AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => { - // this check for functions that are - // group sensitive and doesn't auto-explode (e.g. is a reduction/aggregation - // like sum, min, etc). - // function that match this are `cum_sum`, `shift`, `sort`, etc. 
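    // Review note: `shifts_elements` / `predicate_is_sort_boundary` (and the Len-based
    // retention in the old Union arm) appear to be superseded by the rewritten Union/Sort
    // cases in predicate_pushdown/mod.rs, which now only debug-assert that every accumulated
    // predicate already satisfies `permits_filter_pushdown` and push the predicates through
    // unchanged.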
- options.is_groups_sensitive() && !options.flags.contains(FunctionFlags::RETURNS_SCALAR) - }, - _ => false, - }; - has_aexpr(node, expr_arena, matches) -} - /// Evaluates a condition on the column name inputs of every predicate, where if /// the condition evaluates to true on any column name the predicate is /// transferred to local. @@ -138,73 +110,6 @@ where local_predicates } -/// Extends a stack of nodes with new nodes from `ae` (with some filtering), to support traversing -/// an expression tree to check predicate PD eligibility. Generally called repeatedly with the same -/// stack until all nodes are exhausted. -fn check_and_extend_predicate_pd_nodes( - stack: &mut Vec, - ae: &AExpr, - expr_arena: &Arena, -) -> bool { - if match ae { - // These literals do not come from the RHS of an is_in, meaning that - // they are projected as either columns or predicates, both of which - // rely on the height of the dataframe at this level and thus need - // to block pushdown. - AExpr::Literal(lit) => !lit.projects_as_scalar(), - // Rows that go OOB on get/gather may be filtered out in earlier operations, - // so we don't push these down. - AExpr::Function { - function: FunctionExpr::ListExpr(ListFunction::Get(false)), - .. - } => true, - #[cfg(feature = "list_gather")] - AExpr::Function { - function: FunctionExpr::ListExpr(ListFunction::Gather(false)), - .. - } => true, - #[cfg(feature = "dtype-array")] - AExpr::Function { - function: FunctionExpr::ArrayExpr(ArrayFunction::Get(false)), - .. - } => true, - ae => ae.groups_sensitive(), - } { - false - } else { - match ae { - #[cfg(feature = "is_in")] - AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::IsIn), - input, - .. - } => { - // Handles a special case where the expr contains a series, but it is being - // used as part the RHS of an `is_in`, so it can be pushed down as it is not - // being projected. - let mut transferred_local_nodes = false; - if let Some(rhs) = input.get(1) { - let rhs = rhs.node(); - if matches!(expr_arena.get(rhs), AExpr::Literal { .. 
}) { - let mut local_nodes = Vec::::with_capacity(4); - ae.nodes(&mut local_nodes); - - stack.extend(local_nodes.into_iter().filter(|node| *node != rhs)); - transferred_local_nodes = true; - } - }; - if !transferred_local_nodes { - ae.nodes(stack); - } - }, - ae => { - ae.nodes(stack); - }, - }; - true - } -} - /// * `col(A).alias(B).alias(C) => (C, A)` /// * `col(A) => (A, A)` /// * `col(A).sum().alias(B) => None` @@ -235,11 +140,13 @@ pub enum PushdownEligibility { #[allow(clippy::type_complexity)] pub fn pushdown_eligibility( projection_nodes: &[ExprIR], - new_predicates: &[(PlSmallStr, ExprIR)], + new_predicates: &[ExprIR], acc_predicates: &PlHashMap, expr_arena: &mut Arena, + scratch: &mut Vec, ) -> PolarsResult<(PushdownEligibility, PlHashMap)> { - let mut ae_nodes_stack = Vec::::with_capacity(4); + assert!(scratch.is_empty()); + let ae_nodes_stack = scratch; let mut alias_to_col_map = optimizer::init_hashmap::(Some(projection_nodes.len())); @@ -259,6 +166,8 @@ pub fn pushdown_eligibility( common_window_inputs: &mut PlHashSet| { debug_assert_eq!(ae_nodes_stack.len(), 1); + let mut partition_by_names = PlHashSet::::new(); + while let Some(node) = ae_nodes_stack.pop() { let ae = expr_arena.get(node); @@ -276,8 +185,8 @@ pub fn pushdown_eligibility( return false; }; - let mut partition_by_names = - PlHashSet::::with_capacity(partition_by.len()); + partition_by_names.clear(); + partition_by_names.reserve(partition_by.len()); for node in partition_by.iter() { // Only accept col() @@ -295,7 +204,7 @@ pub fn pushdown_eligibility( } if !*has_window { - for name in partition_by_names.into_iter() { + for name in partition_by_names.drain() { common_window_inputs.insert(name); } @@ -313,7 +222,7 @@ pub fn pushdown_eligibility( } }, _ => { - if !check_and_extend_predicate_pd_nodes(ae_nodes_stack, ae, expr_arena) { + if !permits_filter_pushdown(ae_nodes_stack, ae, expr_arena) { return false; } }, @@ -340,7 +249,7 @@ pub fn pushdown_eligibility( ae_nodes_stack.push(e.node()); if !process_projection_or_predicate( - &mut ae_nodes_stack, + ae_nodes_stack, &mut has_window, &mut common_window_inputs, ) { @@ -372,12 +281,12 @@ pub fn pushdown_eligibility( common_window_inputs = new; } - for (_, e) in new_predicates.iter() { + for e in new_predicates.iter() { debug_assert!(ae_nodes_stack.is_empty()); ae_nodes_stack.push(e.node()); if !process_projection_or_predicate( - &mut ae_nodes_stack, + ae_nodes_stack, &mut has_window, &mut common_window_inputs, ) { @@ -417,7 +326,7 @@ pub fn pushdown_eligibility( can_use_column(name) } else { // May still contain window expressions that need to be blocked. 
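    // Review note: the `check_and_extend_predicate_pd_nodes` helper deleted above is folded
    // into `permits_filter_pushdown` (plans/aexpr/utils.rs), which reuses the shared
    // `is_elementwise` walk while keeping the same special cases: non-scalar literals, the
    // `is_in` right-hand side, and the fallible list/array get/gather variants.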
- check_and_extend_predicate_pd_nodes(&mut ae_nodes_stack, ae, expr_arena) + permits_filter_pushdown(ae_nodes_stack, ae, expr_arena) }; if !can_pushdown { diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs index 9d23552d8a1f..5d958d940ef0 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs @@ -31,7 +31,7 @@ impl OptimizationRule for SlicePushDown { let ae = ae.clone(); self.scratch.clear(); ae.nodes(&mut self.scratch); - let input = self.scratch[0]; + let input = self.scratch.drain(..).next().unwrap(); let new_input = pushdown(input, offset, length, expr_arena); Some(ae.replace_inputs(&[new_input])) }, diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs index 6ecae038f49f..33d0c6a777a6 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs @@ -20,10 +20,20 @@ struct State { /// * at least 1 projection is based on a column (for height broadcast) /// * projections not based on any column project as scalars /// -/// Returns (all_elementwise, all_elementwise_and_any_expr_has_column) -fn can_pushdown_slice_past_projections(exprs: &[ExprIR], arena: &Arena) -> (bool, bool) { - let mut all_elementwise_and_any_expr_has_column = false; +/// Returns (can_pushdown, can_pushdown_and_any_expr_has_column) +fn can_pushdown_slice_past_projections( + exprs: &[ExprIR], + arena: &Arena, + scratch: &mut Vec, +) -> (bool, bool) { + assert!(scratch.is_empty()); + + let mut can_pushdown_and_any_expr_has_column = false; + for expr_ir in exprs.iter() { + scratch.push(expr_ir.node()); + + // # "has_column" // `select(c = Literal([1, 2, 3])).slice(0, 0)` must block slice pushdown, // because `c` projects to a height independent from the input height. We check // this by observing that `c` does not have any columns in its input nodes. @@ -32,31 +42,35 @@ fn can_pushdown_slice_past_projections(exprs: &[ExprIR], arena: &Arena) - // `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`, // `str.contains`, `str.contains_many` etc. - observe a column node is present // but the output height is not dependent on it. - let is_elementwise = is_streamable(expr_ir.node(), arena, Default::default()); - let (has_column, literals_all_scalar) = arena.iter(expr_ir.node()).fold( - (false, true), - |(has_column, lit_scalar), (_node, ae)| { - ( - has_column | matches!(ae, AExpr::Column(_)), - lit_scalar - & if let AExpr::Literal(v) = ae { - v.projects_as_scalar() - } else { - true - }, - ) - }, - ); + let mut has_column = false; + let mut literals_all_scalar = true; + + while let Some(node) = scratch.pop() { + let ae = arena.get(node); + + // We re-use the logic from predicate pushdown, as slices can be seen as a form of filtering. + // But we also do some bookkeeping here specific to slice pushdown. 
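            // Review note: the bookkeeping is the `has_column` / `literals_all_scalar`
            // tracking below; together with the `permits_filter_pushdown` walk it ensures
            // that e.g. `select(c = Literal([1, 2, 3])).slice(0, 0)` from the doc-comment
            // above blocks pushdown, because the projected height would no longer depend on
            // the input height.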
+ + match ae { + AExpr::Column(_) => has_column = true, + AExpr::Literal(v) => literals_all_scalar &= v.projects_as_scalar(), + _ => {}, + } + + if !permits_filter_pushdown(scratch, ae, arena) { + return (false, false); + } + } // If there is no column then all literals must be scalar - if !is_elementwise || !(has_column || literals_all_scalar) { + if !(has_column || literals_all_scalar) { return (false, false); } - all_elementwise_and_any_expr_has_column |= has_column + can_pushdown_and_any_expr_has_column |= has_column } - (true, all_elementwise_and_any_expr_has_column) + (true, can_pushdown_and_any_expr_has_column) } impl SlicePushDown { @@ -93,7 +107,7 @@ impl SlicePushDown { /// slice will be done at this node, but we continue optimization fn no_pushdown_restart_opt( - &self, + &mut self, lp: IR, state: Option, lp_arena: &mut Arena, @@ -120,7 +134,7 @@ impl SlicePushDown { /// slice will be pushed down. fn pushdown_and_continue( - &self, + &mut self, lp: IR, state: Option, lp_arena: &mut Arena, @@ -143,7 +157,7 @@ impl SlicePushDown { #[recursive] fn pushdown( - &self, + &mut self, lp: IR, state: Option, lp_arena: &mut Arena, @@ -473,7 +487,7 @@ impl SlicePushDown { } // there is state, inspect the projection to determine how to deal with it (Select {input, expr, schema, options}, Some(_)) => { - if can_pushdown_slice_past_projections(&expr, expr_arena).1 { + if can_pushdown_slice_past_projections(&expr, expr_arena, &mut self.scratch).1 { let lp = Select {input, expr, schema, options}; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) } @@ -484,14 +498,13 @@ impl SlicePushDown { } } (HStack {input, exprs, schema, options}, _) => { - let (can_pushdown, all_elementwise_and_any_expr_has_column) = can_pushdown_slice_past_projections(&exprs, expr_arena); + let (can_pushdown, can_pushdown_and_any_expr_has_column) = can_pushdown_slice_past_projections(&exprs, expr_arena, &mut self.scratch); - if ( - // If the schema length is greater than an input column is being projected, so + if can_pushdown_and_any_expr_has_column || ( + // If the schema length is greater then an input column is being projected, so // the exprs in with_columns do not need to have an input column name. schema.len() > exprs.len() && can_pushdown ) - || all_elementwise_and_any_expr_has_column // e.g. select(c).with_columns(c = c + 1) { let lp = HStack {input, exprs, schema, options}; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) @@ -514,7 +527,7 @@ impl SlicePushDown { } pub fn optimize( - &self, + &mut self, logical_plan: IR, lp_arena: &mut Arena, expr_arena: &mut Arena, diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index f0df191d395f..8d33d031dfa2 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -106,13 +106,13 @@ pub struct DistinctOptionsIR { pub enum ApplyOptions { /// Collect groups to a list and apply the function over the groups. /// This can be important in aggregation context. - // e.g. [g1, g1, g2] -> [[g1, g1], g2] + /// e.g. [g1, g1, g2] -> [[g1, g1], g2] GroupWise, - // collect groups to a list and then apply - // e.g. [g1, g1, g2] -> list([g1, g1, g2]) + /// collect groups to a list and then apply + /// e.g. [g1, g1, g2] -> list([g1, g1, g2]) ApplyList, - // do not collect before apply - // e.g. [g1, g1, g2] -> [g1, g1, g2] + /// do not collect before apply + /// e.g. 
[g1, g1, g2] -> [g1, g1, g2] ElementWise, } @@ -200,14 +200,6 @@ pub struct FunctionOptions { } impl FunctionOptions { - /// Any function that is sensitive to the number of elements in a group - /// - Aggregations - /// - Sorts - /// - Counts - pub fn is_groups_sensitive(&self) -> bool { - matches!(self.collect_groups, ApplyOptions::GroupWise) - } - #[cfg(feature = "fused")] pub(crate) unsafe fn no_check_lengths(&mut self) { self.check_lengths = UnsafeBool(false); @@ -217,10 +209,12 @@ impl FunctionOptions { } pub fn is_elementwise(&self) -> bool { - self.collect_groups == ApplyOptions::ElementWise - && !self - .flags - .contains(FunctionFlags::CHANGES_LENGTH | FunctionFlags::RETURNS_SCALAR) + matches!( + self.collect_groups, + ApplyOptions::ElementWise | ApplyOptions::ApplyList + ) && !self + .flags + .contains(FunctionFlags::CHANGES_LENGTH | FunctionFlags::RETURNS_SCALAR) } } diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index 4bdf483474c3..fdf9c979738a 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -40,6 +40,7 @@ pub(crate) fn fmt_column_delimited>( write!(f, "{container_end}") } +// TODO: Remove this and use `Extend` instead. pub trait PushNode { fn push_node(&mut self, value: Node); diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index 3af80df16f9f..d9ae84fa121b 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -9,7 +9,7 @@ use polars_expr::planner::get_expr_depth_limit; use polars_expr::state::ExecutionState; use polars_expr::{create_physical_expr, ExpressionConversionState}; use polars_plan::plans::expr_ir::{ExprIR, OutputName}; -use polars_plan::plans::{AExpr, LiteralValue}; +use polars_plan::plans::AExpr; use polars_plan::prelude::*; use polars_utils::arena::{Arena, Node}; use polars_utils::format_pl_smallstr; @@ -47,73 +47,36 @@ struct LowerExprContext<'a> { cache: &'a mut ExprCache, } -#[recursive::recursive] -pub(crate) fn is_elementwise( +pub(crate) fn is_elementwise_rec_cached( expr_key: IRNodeKey, arena: &Arena, cache: &mut ExprCache, ) -> bool { - if let Some(ret) = cache.is_elementwise.get(&expr_key) { - return *ret; - } + if !cache.is_elementwise.contains_key(&expr_key) { + cache.is_elementwise.insert( + expr_key, + (|| { + let mut expr_key = expr_key; + let mut stack = vec![]; + + loop { + if !polars_plan::plans::is_elementwise(&mut stack, arena.get(expr_key), arena) { + return false; + } + + let Some(next_key) = stack.pop() else { + break; + }; - let ret = match arena.get(expr_key) { - AExpr::Explode(_) => false, - AExpr::Alias(inner, _) => is_elementwise(*inner, arena, cache), - AExpr::Column(_) => true, - AExpr::Literal(lit) => !matches!(lit, LiteralValue::Series(_) | LiteralValue::Range { .. }), - AExpr::BinaryExpr { left, op: _, right } => { - is_elementwise(*left, arena, cache) && is_elementwise(*right, arena, cache) - }, - AExpr::Cast { - expr, - dtype: _, - options: _, - } => is_elementwise(*expr, arena, cache), - AExpr::Sort { .. } | AExpr::SortBy { .. } | AExpr::Gather { .. } => false, - AExpr::Filter { .. 
} => false, - AExpr::Agg(_) => false, - AExpr::Ternary { - predicate, - truthy, - falsy, - } => { - is_elementwise(*predicate, arena, cache) - && is_elementwise(*truthy, arena, cache) - && is_elementwise(*falsy, arena, cache) - }, - AExpr::AnonymousFunction { - input, - function: _, - output_type: _, - options, - } => { - options.is_elementwise() && input.iter().all(|e| is_elementwise(e.node(), arena, cache)) - }, - AExpr::Function { - input, - function, - options, - } => { - match function { - // Non-strict strptime must be done in-memory to ensure the format - // is consistent across the entire dataframe. - #[cfg(feature = "strings")] - FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)) => opts.strict, - _ => { - options.is_elementwise() - && input.iter().all(|e| is_elementwise(e.node(), arena, cache)) - }, - } - }, + expr_key = next_key; + } - AExpr::Window { .. } => false, - AExpr::Slice { .. } => false, - AExpr::Len => false, - }; + true + })(), + ); + } - cache.is_elementwise.insert(expr_key, ret); - ret + *cache.is_elementwise.get(&expr_key).unwrap() } #[recursive::recursive] @@ -403,7 +366,7 @@ fn lower_exprs_with_ctx( let mut transformed_exprs = Vec::with_capacity(exprs.len()); for expr in exprs.iter().copied() { - if is_elementwise(expr, ctx.expr_arena, ctx.cache) { + if is_elementwise_rec_cached(expr, ctx.expr_arena, ctx.cache) { if !is_input_independent(expr, ctx) { input_nodes.insert(input); } diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index db999481bf97..fcbec84a2e53 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -12,7 +12,9 @@ use polars_utils::itertools::Itertools; use slotmap::SlotMap; use super::{PhysNode, PhysNodeKey, PhysNodeKind}; -use crate::physical_plan::lower_expr::{build_select_node, is_elementwise, lower_exprs, ExprCache}; +use crate::physical_plan::lower_expr::{ + build_select_node, is_elementwise_rec_cached, lower_exprs, ExprCache, +}; fn build_slice_node( input: PhysNodeKey, @@ -79,7 +81,7 @@ pub fn lower_ir( IR::HStack { input, exprs, .. } if exprs .iter() - .all(|e| is_elementwise(e.node(), expr_arena, expr_cache)) => + .all(|e| is_elementwise_rec_cached(e.node(), expr_arena, expr_cache)) => { // FIXME: constant literal columns should be broadcasted with hstack. let selectors = exprs.clone(); @@ -189,7 +191,7 @@ pub fn lower_ir( } if let Some(predicate) = filter.clone() { - if !is_elementwise(predicate.node(), expr_arena, expr_cache) { + if !is_elementwise_rec_cached(predicate.node(), expr_arena, expr_cache) { todo!() } @@ -461,7 +463,7 @@ pub fn lower_ir( | IRAggExpr::Sum(input) | IRAggExpr::Var(input, ..) | IRAggExpr::Std(input, ..) 
=> { - if is_elementwise(*input, expr_arena, expr_cache) { + if is_elementwise_rec_cached(*input, expr_arena, expr_cache) { input_exprs.push(ExprIR::from_node(*input, expr_arena)); } else { todo!() diff --git a/crates/polars-stream/src/skeleton.rs b/crates/polars-stream/src/skeleton.rs index 9516be3b902a..97e6a9c73272 100644 --- a/crates/polars-stream/src/skeleton.rs +++ b/crates/polars-stream/src/skeleton.rs @@ -4,16 +4,12 @@ use std::cmp::Reverse; use polars_core::prelude::*; use polars_core::POOL; use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState}; -use polars_plan::plans::{Context, IRPlan, IsStreamableContext, IR}; +use polars_plan::plans::{Context, IRPlan, IR}; use polars_plan::prelude::expr_ir::ExprIR; use polars_plan::prelude::AExpr; use polars_utils::arena::{Arena, Node}; use slotmap::{SecondaryMap, SlotMap}; -fn is_streamable(node: Node, arena: &Arena) -> bool { - polars_plan::plans::is_streamable(node, arena, IsStreamableContext::new(Context::Default)) -} - pub fn run_query( node: Node, mut ir_arena: Arena, diff --git a/py-polars/tests/unit/streaming/test_streaming.py b/py-polars/tests/unit/streaming/test_streaming.py index 225c0b97553c..84aae68616a1 100644 --- a/py-polars/tests/unit/streaming/test_streaming.py +++ b/py-polars/tests/unit/streaming/test_streaming.py @@ -369,3 +369,22 @@ def test_streaming_with_hconcat(tmp_path: Path) -> None: ) assert_frame_equal(result, expected) + + +@pytest.mark.write_disk +def test_elementwise_identification_in_ternary_15767(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + ( + pl.LazyFrame({"a": pl.Series([1])}) + .with_columns(b=pl.col("a").is_in(pl.Series([1, 2, 3]))) + .sink_parquet(tmp_path / "1") + ) + + ( + pl.LazyFrame({"a": pl.Series([1])}) + .with_columns( + b=pl.when(pl.col("a").is_in(pl.Series([1, 2, 3]))).then(pl.col("a")) + ) + .sink_parquet(tmp_path / "1") + ) diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index 7a458d65e7fe..ac9316616a76 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -5,6 +5,7 @@ import pytest import polars as pl +from polars.exceptions import ComputeError from polars.testing import assert_frame_equal from polars.testing.asserts.series import assert_series_equal @@ -497,7 +498,7 @@ def test_predicate_push_down_with_alias_15442() -> None: assert output.to_dict(as_series=False) == {"a": [1]} -def test_predicate_push_down_list_gather_17492() -> None: +def test_predicate_slice_pushdown_list_gather_17492() -> None: lf = pl.LazyFrame({"val": [[1], [1, 1]], "len": [1, 2]}) assert_frame_equal( @@ -512,6 +513,12 @@ def test_predicate_push_down_list_gather_17492() -> None: .explain() ) + # Also check slice pushdown + q = lf.with_columns(pl.col("val").list.get(1).alias("b")).slice(1, 1) + + with pytest.raises(ComputeError, match="get index is out of bounds"): + q.collect() + def test_predicate_pushdown_struct_unnest_19632() -> None: lf = pl.LazyFrame({"a": [{"a": 1, "b": 2}]}).unnest("a")