Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PruningPredicate's rewrite public #12850

Merged
merged 5 commits into from
Oct 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 188 additions & 24 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ pub trait PruningStatistics {
/// [`Snowflake SIGMOD Paper`]: https://dl.acm.org/doi/10.1145/2882903.2903741
/// [small materialized aggregates]: https://www.vldb.org/conf/1998/p476.pdf
/// [zone maps]: https://dl.acm.org/doi/10.1007/978-3-642-03730-6_10
///[data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
/// [data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
#[derive(Debug, Clone)]
pub struct PruningPredicate {
/// The input schema against which the predicate will be evaluated
Expand All @@ -478,6 +478,36 @@ pub struct PruningPredicate {
literal_guarantees: Vec<LiteralGuarantee>,
}

/// Rewrites predicates that [`PredicateRewriter`] can not handle, e.g. certain
/// complex expressions or predicates that reference columns that are not in the
/// schema.
pub trait UnhandledPredicateHook {
/// Called when a predicate can not be rewritten in terms of statistics or
/// references a column that is not in the schema.
fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
}

/// The default handling for unhandled predicates is to return a constant `true`
/// (meaning don't prune the container)
#[derive(Debug, Clone)]
struct ConstantUnhandledPredicateHook {
default: Arc<dyn PhysicalExpr>,
}

impl Default for ConstantUnhandledPredicateHook {
fn default() -> Self {
Self {
default: Arc::new(phys_expr::Literal::new(ScalarValue::from(true))),
}
}
}

impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
self.default.clone()
}
}

impl PruningPredicate {
/// Try to create a new instance of [`PruningPredicate`]
///
Expand All @@ -502,10 +532,16 @@ impl PruningPredicate {
/// See the struct level documentation on [`PruningPredicate`] for more
/// details.
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) -> Result<Self> {
let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _;

// build predicate expression once
let mut required_columns = RequiredColumns::new();
let predicate_expr =
build_predicate_expression(&expr, schema.as_ref(), &mut required_columns);
let predicate_expr = build_predicate_expression(
&expr,
schema.as_ref(),
&mut required_columns,
&unhandled_hook,
);

let literal_guarantees = LiteralGuarantee::analyze(&expr);

Expand Down Expand Up @@ -1312,27 +1348,78 @@ fn build_is_null_column_expr(
/// an OR chain
const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20;

/// Rewrite a predicate expression in terms of statistics (min/max/null_counts)
/// for use as a [`PruningPredicate`].
pub struct PredicateRewriter {
unhandled_hook: Arc<dyn UnhandledPredicateHook>,
}

impl Default for PredicateRewriter {
fn default() -> Self {
Self {
unhandled_hook: Arc::new(ConstantUnhandledPredicateHook::default()),
}
}
}

impl PredicateRewriter {
/// Create a new `PredicateRewriter`
pub fn new() -> Self {
Self::default()
}

/// Set the unhandled hook to be used when a predicate can not be rewritten
pub fn with_unhandled_hook(
self,
unhandled_hook: Arc<dyn UnhandledPredicateHook>,
) -> Self {
Self { unhandled_hook }
}

/// Translate logical filter expression into pruning predicate
/// expression that will evaluate to FALSE if it can be determined no
/// rows between the min/max values could pass the predicates.
///
/// Any predicates that can not be translated will be passed to `unhandled_hook`.
///
/// Returns the pruning predicate as an [`PhysicalExpr`]
///
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
pub fn rewrite_predicate_to_statistics_predicate(
&self,
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
) -> Arc<dyn PhysicalExpr> {
let mut required_columns = RequiredColumns::new();
build_predicate_expression(
expr,
schema,
&mut required_columns,
&self.unhandled_hook,
)
}
}

/// Translate logical filter expression into pruning predicate
/// expression that will evaluate to FALSE if it can be determined no
/// rows between the min/max values could pass the predicates.
///
/// Any predicates that can not be translated will be passed to `unhandled_hook`.
///
/// Returns the pruning predicate as an [`PhysicalExpr`]
///
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
fn build_predicate_expression(
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
required_columns: &mut RequiredColumns,
unhandled_hook: &Arc<dyn UnhandledPredicateHook>,
) -> Arc<dyn PhysicalExpr> {
// Returned for unsupported expressions. Such expressions are
// converted to TRUE.
let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));

// predicate expression can only be a binary expression
let expr_any = expr.as_any();
if let Some(is_null) = expr_any.downcast_ref::<phys_expr::IsNullExpr>() {
return build_is_null_column_expr(is_null.arg(), schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(is_not_null) = expr_any.downcast_ref::<phys_expr::IsNotNullExpr>() {
return build_is_null_column_expr(
Expand All @@ -1341,19 +1428,19 @@ fn build_predicate_expression(
required_columns,
true,
)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(col) = expr_any.downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
// match !col (don't do so recursively)
if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, true)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
Expand Down Expand Up @@ -1382,9 +1469,14 @@ fn build_predicate_expression(
})
.reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _)
.unwrap();
return build_predicate_expression(&change_expr, schema, required_columns);
return build_predicate_expression(
&change_expr,
schema,
required_columns,
unhandled_hook,
);
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}

Expand All @@ -1396,21 +1488,23 @@ fn build_predicate_expression(
bin_expr.right().clone(),
)
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
};

if op == Operator::And || op == Operator::Or {
let left_expr = build_predicate_expression(&left, schema, required_columns);
let right_expr = build_predicate_expression(&right, schema, required_columns);
let left_expr =
build_predicate_expression(&left, schema, required_columns, unhandled_hook);
let right_expr =
build_predicate_expression(&right, schema, required_columns, unhandled_hook);
// simplify boolean expression if applicable
let expr = match (&left_expr, op, &right_expr) {
(left, Operator::And, _) if is_always_true(left) => right_expr,
(_, Operator::And, right) if is_always_true(right) => left_expr,
(left, Operator::Or, right)
if is_always_true(left) || is_always_true(right) =>
{
unhandled
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
}
_ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)),
};
Expand All @@ -1423,12 +1517,11 @@ fn build_predicate_expression(
Ok(builder) => builder,
// allow partial failure in predicate expression generation
// this can still produce a useful predicate when multiple conditions are joined using AND
Err(_) => {
return unhandled;
}
Err(_) => return unhandled_hook.handle(expr),
};

build_statistics_expr(&mut expr_builder).unwrap_or(unhandled)
build_statistics_expr(&mut expr_builder)
.unwrap_or_else(|_| unhandled_hook.handle(expr))
}

fn build_statistics_expr(
Expand Down Expand Up @@ -1582,6 +1675,8 @@ mod tests {
use arrow_array::UInt64Array;
use datafusion_expr::expr::InList;
use datafusion_expr::{cast, is_null, try_cast, Expr};
use datafusion_functions_nested::expr_fn::{array_has, make_array};
use datafusion_physical_expr::expressions as phys_expr;
use datafusion_physical_expr::planner::logical2physical;

#[derive(Debug, Default)]
Expand Down Expand Up @@ -3397,6 +3492,74 @@ mod tests {
// TODO: add test for other case and op
}

#[test]
fn test_rewrite_expr_to_prunable_custom_unhandled_hook() {
struct CustomUnhandledHook;

impl UnhandledPredicateHook for CustomUnhandledHook {
/// This handles an arbitrary case of a column that doesn't exist in the schema
/// by renaming it to yet another column that doesn't exist in the schema
/// (the transformation is arbitrary, the point is that it can do whatever it wants)
fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42))))
}
}

let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
let schema_with_b = Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]);

let rewriter = PredicateRewriter::new()
.with_unhandled_hook(Arc::new(CustomUnhandledHook {}));

let transform_expr = |expr| {
let expr = logical2physical(&expr, &schema_with_b);
rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema)
};

// transform an arbitrary valid expression that we know is handled
let known_expression = col("a").eq(lit(12));
let known_expression_transformed = PredicateRewriter::new()
.rewrite_predicate_to_statistics_predicate(
&logical2physical(&known_expression, &schema),
&schema,
);

// an expression referencing an unknown column (that is not in the schema) gets passed to the hook
let input = col("b").eq(lit(12));
let expected = logical2physical(&lit(42), &schema);
let transformed = transform_expr(input.clone());
assert_eq!(transformed.to_string(), expected.to_string());

// more complex case with unknown column
let input = known_expression.clone().and(input.clone());
let expected = phys_expr::BinaryExpr::new(
known_expression_transformed.clone(),
Operator::And,
logical2physical(&lit(42), &schema),
);
let transformed = transform_expr(input.clone());
assert_eq!(transformed.to_string(), expected.to_string());

// an unknown expression gets passed to the hook
let input = array_has(make_array(vec![lit(1)]), col("a"));
let expected = logical2physical(&lit(42), &schema);
let transformed = transform_expr(input.clone());
assert_eq!(transformed.to_string(), expected.to_string());

// more complex case with unknown expression
let input = known_expression.and(input);
let expected = phys_expr::BinaryExpr::new(
known_expression_transformed.clone(),
Operator::And,
logical2physical(&lit(42), &schema),
);
let transformed = transform_expr(input.clone());
assert_eq!(transformed.to_string(), expected.to_string());
}

#[test]
fn test_rewrite_expr_to_prunable_error() {
// cast string value to numeric value
Expand Down Expand Up @@ -3886,6 +4049,7 @@ mod tests {
required_columns: &mut RequiredColumns,
) -> Arc<dyn PhysicalExpr> {
let expr = logical2physical(expr, schema);
build_predicate_expression(&expr, schema, required_columns)
let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _;
build_predicate_expression(&expr, schema, required_columns, &unhandled_hook)
}
}