From 9c49413cf798d8f826bb24e0ce63f6738f835f8c Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 10 Oct 2024 08:04:36 -0500 Subject: [PATCH 1/4] Make PruningPredicate's rewrite public --- .../core/src/physical_optimizer/pruning.rs | 181 +++++++++++++++--- 1 file changed, 158 insertions(+), 23 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 9bc2bb1d1db9..09c5a5a54d71 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -478,6 +478,37 @@ pub struct PruningPredicate { literal_guarantees: Vec, } +/// Hook to handle predicates that DataFusion can not handle, e.g. certain complex expressions +/// or predicates that reference columns that are not in the schema. +pub trait UnhandledPredicateHook { + /// Called when a predicate can not be handled by DataFusion's transformation rules + /// or is referencing a column that is not in the schema. + fn handle(&self, expr: &Arc) -> Arc; +} + +#[derive(Debug, Clone)] +struct ConstantUnhandledPredicateHook { + default: Arc, +} + +impl ConstantUnhandledPredicateHook { + fn new(default: Arc) -> Self { + Self { default } + } +} + +impl UnhandledPredicateHook for ConstantUnhandledPredicateHook { + fn handle(&self, _expr: &Arc) -> Arc { + self.default.clone() + } +} + +fn default_unhandled_hook() -> Arc { + Arc::new(ConstantUnhandledPredicateHook::new(Arc::new( + phys_expr::Literal::new(ScalarValue::Boolean(Some(true))), + ))) +} + impl PruningPredicate { /// Try to create a new instance of [`PruningPredicate`] /// @@ -502,10 +533,16 @@ impl PruningPredicate { /// See the struct level documentation on [`PruningPredicate`] for more /// details. pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { + let unhandled_hook = default_unhandled_hook(); + // build predicate expression once let mut required_columns = RequiredColumns::new(); - let predicate_expr = - build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + let predicate_expr = build_predicate_expression( + &expr, + schema.as_ref(), + &mut required_columns, + &unhandled_hook, + ); let literal_guarantees = LiteralGuarantee::analyze(&expr); @@ -1316,23 +1353,43 @@ const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20; /// expression that will evaluate to FALSE if it can be determined no /// rows between the min/max values could pass the predicates. /// +/// Any predicates that can not be translated will be passed to `unhandled_hook`. +/// /// Returns the pruning predicate as an [`PhysicalExpr`] /// -/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE +/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` +pub fn rewrite_predicate_to_statistics_predicate( + expr: &Arc, + schema: &Schema, + unhandled_hook: Option>, +) -> Arc { + let unhandled_hook = unhandled_hook.unwrap_or(default_unhandled_hook()); + + let mut required_columns = RequiredColumns::new(); + + build_predicate_expression(expr, schema, &mut required_columns, &unhandled_hook) +} + +/// Translate logical filter expression into pruning predicate +/// expression that will evaluate to FALSE if it can be determined no +/// rows between the min/max values could pass the predicates. +/// +/// Any predicates that can not be translated will be passed to `unhandled_hook`. +/// +/// Returns the pruning predicate as an [`PhysicalExpr`] +/// +/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` fn build_predicate_expression( expr: &Arc, schema: &Schema, required_columns: &mut RequiredColumns, + unhandled_hook: &Arc, ) -> Arc { - // Returned for unsupported expressions. Such expressions are - // converted to TRUE. - let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))); - // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { return build_is_null_column_expr(is_null.arg(), schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(is_not_null) = expr_any.downcast_ref::() { return build_is_null_column_expr( @@ -1341,19 +1398,19 @@ fn build_predicate_expression( required_columns, true, ) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(col) = expr_any.downcast_ref::() { return build_single_column_expr(col, schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(not) = expr_any.downcast_ref::() { // match !col (don't do so recursively) if let Some(col) = not.arg().as_any().downcast_ref::() { return build_single_column_expr(col, schema, required_columns, true) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } else { - return unhandled; + return unhandled_hook.handle(expr); } } if let Some(in_list) = expr_any.downcast_ref::() { @@ -1382,9 +1439,14 @@ fn build_predicate_expression( }) .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _) .unwrap(); - return build_predicate_expression(&change_expr, schema, required_columns); + return build_predicate_expression( + &change_expr, + schema, + required_columns, + unhandled_hook, + ); } else { - return unhandled; + return unhandled_hook.handle(expr); } } @@ -1396,13 +1458,15 @@ fn build_predicate_expression( bin_expr.right().clone(), ) } else { - return unhandled; + return unhandled_hook.handle(expr); } }; if op == Operator::And || op == Operator::Or { - let left_expr = build_predicate_expression(&left, schema, required_columns); - let right_expr = build_predicate_expression(&right, schema, required_columns); + let left_expr = + build_predicate_expression(&left, schema, required_columns, unhandled_hook); + let right_expr = + build_predicate_expression(&right, schema, required_columns, unhandled_hook); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { (left, Operator::And, _) if is_always_true(left) => right_expr, @@ -1410,7 +1474,7 @@ fn build_predicate_expression( (left, Operator::Or, right) if is_always_true(left) || is_always_true(right) => { - unhandled + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) } _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; @@ -1423,12 +1487,11 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => { - return unhandled; - } + Err(_) => return unhandled_hook.handle(expr), }; - build_statistics_expr(&mut expr_builder).unwrap_or(unhandled) + build_statistics_expr(&mut expr_builder) + .unwrap_or_else(|_| unhandled_hook.handle(expr)) } fn build_statistics_expr( @@ -1582,6 +1645,8 @@ mod tests { use arrow_array::UInt64Array; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; + use datafusion_functions_nested::expr_fn::{array_has, make_array}; + use datafusion_physical_expr::expressions as phys_expr; use datafusion_physical_expr::planner::logical2physical; #[derive(Debug, Default)] @@ -3397,6 +3462,75 @@ mod tests { // TODO: add test for other case and op } + #[test] + fn test_rewrite_expr_to_prunable_custom_unhandled_hook() { + struct CustomUnhandledHook; + + impl UnhandledPredicateHook for CustomUnhandledHook { + /// This handles an arbitrary case of a column that doesn't exist in the schema + /// by renaming it to yet another column that doesn't exist in the schema + /// (the transformation is arbitrary, the point is that it can do whatever it wants) + fn handle(&self, _expr: &Arc) -> Arc { + Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42)))) + } + } + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let schema_with_b = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let transform_expr = |expr| { + let expr = logical2physical(&expr, &schema_with_b); + rewrite_predicate_to_statistics_predicate( + &expr, + &schema, + Some(Arc::new(CustomUnhandledHook {})), + ) + }; + + // transform an arbitrary valid expression that we know is handled + let known_expression = col("a").eq(lit(ScalarValue::Int32(Some(12)))); + let known_expression_transformed = rewrite_predicate_to_statistics_predicate( + &logical2physical(&known_expression, &schema), + &schema, + None, + ); + + // an expression referencing an unknown column (that is not in the schema) gets passed to the hook + let input = col("b").eq(lit(ScalarValue::Int32(Some(12)))); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown column + let input = known_expression.clone().and(input.clone()); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // an unknown expression gets passed to the hook + let input = array_has(make_array(vec![lit(1)]), col("a")); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown expression + let input = known_expression.and(input); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + } + #[test] fn test_rewrite_expr_to_prunable_error() { // cast string value to numeric value @@ -3886,6 +4020,7 @@ mod tests { required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); - build_predicate_expression(&expr, schema, required_columns) + let unhandled_hook = default_unhandled_hook(); + build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) } } From d249584215e5259e9fa2f3392abf82c4cf3579e6 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:09:00 -0500 Subject: [PATCH 2/4] feedback --- .../core/src/physical_optimizer/pruning.rs | 87 ++++++++++++------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 09c5a5a54d71..46ab4eca821e 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -1348,26 +1348,56 @@ fn build_is_null_column_expr( /// The maximum number of entries in an `InList` that might be rewritten into /// an OR chain const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20; +/// Rewrite a predicate expression to a pruning predicate expression. +/// See documentation for `PruningPredicate` for more information. +pub struct PredicateRewriter { + unhandled_hook: Arc, +} -/// Translate logical filter expression into pruning predicate -/// expression that will evaluate to FALSE if it can be determined no -/// rows between the min/max values could pass the predicates. -/// -/// Any predicates that can not be translated will be passed to `unhandled_hook`. -/// -/// Returns the pruning predicate as an [`PhysicalExpr`] -/// -/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` -pub fn rewrite_predicate_to_statistics_predicate( - expr: &Arc, - schema: &Schema, - unhandled_hook: Option>, -) -> Arc { - let unhandled_hook = unhandled_hook.unwrap_or(default_unhandled_hook()); +impl Default for PredicateRewriter { + fn default() -> Self { + Self { + unhandled_hook: default_unhandled_hook(), + } + } +} - let mut required_columns = RequiredColumns::new(); +impl PredicateRewriter { + /// Create a new `PredicateRewriter` + pub fn new() -> Self { + Self::default() + } + + /// Set the unhandled hook to be used when a predicate can not be rewritten + pub fn with_unhandled_hook( + self, + unhandled_hook: Arc, + ) -> Self { + Self { unhandled_hook } + } - build_predicate_expression(expr, schema, &mut required_columns, &unhandled_hook) + /// Translate logical filter expression into pruning predicate + /// expression that will evaluate to FALSE if it can be determined no + /// rows between the min/max values could pass the predicates. + /// + /// Any predicates that can not be translated will be passed to `unhandled_hook`. + /// + /// Returns the pruning predicate as an [`PhysicalExpr`] + /// + /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` + pub fn rewrite_predicate_to_statistics_predicate( + &self, + expr: &Arc, + schema: &Schema, + ) -> Arc { + let mut required_columns = RequiredColumns::new(); + build_predicate_expression( + expr, + schema, + &mut required_columns, + &self.unhandled_hook, + ) + } } /// Translate logical filter expression into pruning predicate @@ -3481,25 +3511,24 @@ mod tests { Field::new("b", DataType::Int32, true), ]); + let rewriter = PredicateRewriter::new() + .with_unhandled_hook(Arc::new(CustomUnhandledHook {})); + let transform_expr = |expr| { let expr = logical2physical(&expr, &schema_with_b); - rewrite_predicate_to_statistics_predicate( - &expr, - &schema, - Some(Arc::new(CustomUnhandledHook {})), - ) + rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema) }; // transform an arbitrary valid expression that we know is handled - let known_expression = col("a").eq(lit(ScalarValue::Int32(Some(12)))); - let known_expression_transformed = rewrite_predicate_to_statistics_predicate( - &logical2physical(&known_expression, &schema), - &schema, - None, - ); + let known_expression = col("a").eq(lit(12)); + let known_expression_transformed = PredicateRewriter::new() + .rewrite_predicate_to_statistics_predicate( + &logical2physical(&known_expression, &schema), + &schema, + ); // an expression referencing an unknown column (that is not in the schema) gets passed to the hook - let input = col("b").eq(lit(ScalarValue::Int32(Some(12)))); + let input = col("b").eq(lit(12)); let expected = logical2physical(&lit(42), &schema); let transformed = transform_expr(input.clone()); assert_eq!(transformed.to_string(), expected.to_string()); From e8dba08685ace5527eb49ea0ab419cb10a9b67a1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 11 Oct 2024 10:26:48 -0400 Subject: [PATCH 3/4] Improve documentation and add default to ConstantUnhandledPredicatehook --- .../core/src/physical_optimizer/pruning.rs | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 46ab4eca821e..798ce0bccf64 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -458,7 +458,7 @@ pub trait PruningStatistics { /// [`Snowflake SIGMOD Paper`]: https://dl.acm.org/doi/10.1145/2882903.2903741 /// [small materialized aggregates]: https://www.vldb.org/conf/1998/p476.pdf /// [zone maps]: https://dl.acm.org/doi/10.1007/978-3-642-03730-6_10 -///[data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515 +/// [data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515 #[derive(Debug, Clone)] pub struct PruningPredicate { /// The input schema against which the predicate will be evaluated @@ -478,19 +478,30 @@ pub struct PruningPredicate { literal_guarantees: Vec, } -/// Hook to handle predicates that DataFusion can not handle, e.g. certain complex expressions -/// or predicates that reference columns that are not in the schema. +/// Rewrites predicates that [`PredicateRewriter`] can not handle, e.g. certain +/// complex expressions or predicates that reference columns that are not in the +/// schema. pub trait UnhandledPredicateHook { - /// Called when a predicate can not be handled by DataFusion's transformation rules - /// or is referencing a column that is not in the schema. + /// Called when a predicate can not be rewritten in terms of statistics or + /// references a column that is not in the schema. fn handle(&self, expr: &Arc) -> Arc; } +/// The default handling for unhandled predicates is to return a constant `true` +/// (meaning don't prune the container) #[derive(Debug, Clone)] struct ConstantUnhandledPredicateHook { default: Arc, } +impl Default for ConstantUnhandledPredicateHook { + fn default() -> Self { + Self { + default: Arc::new(phys_expr::Literal::new(ScalarValue::from(true))), + } + } +} + impl ConstantUnhandledPredicateHook { fn new(default: Arc) -> Self { Self { default } @@ -503,12 +514,6 @@ impl UnhandledPredicateHook for ConstantUnhandledPredicateHook { } } -fn default_unhandled_hook() -> Arc { - Arc::new(ConstantUnhandledPredicateHook::new(Arc::new( - phys_expr::Literal::new(ScalarValue::Boolean(Some(true))), - ))) -} - impl PruningPredicate { /// Try to create a new instance of [`PruningPredicate`] /// @@ -533,7 +538,7 @@ impl PruningPredicate { /// See the struct level documentation on [`PruningPredicate`] for more /// details. pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { - let unhandled_hook = default_unhandled_hook(); + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; // build predicate expression once let mut required_columns = RequiredColumns::new(); @@ -1348,8 +1353,9 @@ fn build_is_null_column_expr( /// The maximum number of entries in an `InList` that might be rewritten into /// an OR chain const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20; -/// Rewrite a predicate expression to a pruning predicate expression. -/// See documentation for `PruningPredicate` for more information. + +/// Rewrite a predicate expression in terms of statistics (min/max/null_counts) +/// for use as a [`PruningPredicate`]. pub struct PredicateRewriter { unhandled_hook: Arc, } @@ -1357,7 +1363,7 @@ pub struct PredicateRewriter { impl Default for PredicateRewriter { fn default() -> Self { Self { - unhandled_hook: default_unhandled_hook(), + unhandled_hook: Arc::new(ConstantUnhandledPredicateHook::default()), } } } @@ -4049,7 +4055,7 @@ mod tests { required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); - let unhandled_hook = default_unhandled_hook(); + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) } } From c438fd00c14fe70a93b747ba32bcf76e57b94c58 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 11 Oct 2024 12:04:56 -0500 Subject: [PATCH 4/4] Update pruning.rs --- datafusion/core/src/physical_optimizer/pruning.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 798ce0bccf64..eb03b337779c 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -502,12 +502,6 @@ impl Default for ConstantUnhandledPredicateHook { } } -impl ConstantUnhandledPredicateHook { - fn new(default: Arc) -> Self { - Self { default } - } -} - impl UnhandledPredicateHook for ConstantUnhandledPredicateHook { fn handle(&self, _expr: &Arc) -> Arc { self.default.clone()