From 16db847130bda06d3fc3d8b42cd303126c4d7995 Mon Sep 17 00:00:00 2001
From: Jesse Bakker
Date: Wed, 22 Nov 2023 18:22:18 +0100
Subject: [PATCH] Detect when filters make subqueries scalar

---
 .../common/src/functional_dependencies.rs     |   8 +
 datafusion/expr/src/logical_plan/plan.rs      | 138 +++++++++++++++++-
 .../sqllogictest/test_files/subquery.slt      |  18 +++
 3 files changed, 161 insertions(+), 3 deletions(-)

diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs
index fbddcddab4bc..4587677e7726 100644
--- a/datafusion/common/src/functional_dependencies.rs
+++ b/datafusion/common/src/functional_dependencies.rs
@@ -413,6 +413,14 @@ impl FunctionalDependencies {
     }
 }
 
+impl Deref for FunctionalDependencies {
+    type Target = [FunctionalDependence];
+
+    fn deref(&self) -> &Self::Target {
+        self.deps.as_slice()
+    }
+}
+
 /// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression.
 pub fn aggregate_functional_dependencies(
     aggr_input_schema: &DFSchema,
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index ea7a48d2c4f4..980719959da2 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -33,6 +33,7 @@ use crate::logical_plan::{DmlStatement, Statement};
 use crate::utils::{
     enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs,
     grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre,
+    split_conjunction,
 };
 use crate::{
     build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr,
@@ -47,7 +48,7 @@ use datafusion_common::tree_node::{
 };
 use datafusion_common::{
     aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
-    DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
+    DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies,
     OwnedTableReference, Result, ScalarValue, UnnestOptions,
 };
 // backwards compatibility
@@ -1033,7 +1034,13 @@ impl LogicalPlan {
     pub fn max_rows(self: &LogicalPlan) -> Option<usize> {
         match self {
             LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(),
-            LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(),
+            LogicalPlan::Filter(filter) => {
+                if filter.is_scalar() {
+                    Some(1)
+                } else {
+                    filter.input.max_rows()
+                }
+            }
             LogicalPlan::Window(Window { input, .. }) => input.max_rows(),
             LogicalPlan::Aggregate(Aggregate {
                 input, group_expr, ..
@@ -1935,6 +1942,73 @@ impl Filter {
 
         Ok(Self { predicate, input })
     }
+
+    /// Is this filter guaranteed to return 0 or 1 row in a given instantiation?
+    ///
+    /// This function will return `true` if its predicate contains a conjunction of
+    /// `col(a) = <expr>`, where its schema has a unique key that is covered
+    /// by this conjunction.
+    ///
+    /// For example, for the table:
+    /// ```sql
+    /// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER);
+    /// ```
+    /// - `Filter(a = 2).is_scalar() == true`
+    /// - `Filter(b = 2).is_scalar() == false`
+    /// - `Filter(a = 2 OR b = 2).is_scalar() == false`
+    ///
+    /// Equality columns that the input schema cannot resolve are ignored.
+    fn is_scalar(&self) -> bool {
+        let schema = self.input.schema();
+
+        let unique_keys = schema.functional_dependencies().iter().filter(|dep| {
+            let nullable = dep.nullable
+                && dep
+                    .source_indices
+                    .iter()
+                    .any(|&source| schema.field(source).is_nullable());
+            !nullable
+                && dep.mode == Dependency::Single
+                && dep.target_indices.len() == schema.fields().len()
+        });
+
+        let exprs = split_conjunction(&self.predicate);
+        let eq_pred_cols: HashSet<_> = exprs
+            .iter()
+            .filter_map(|expr| {
+                let Expr::BinaryExpr(BinaryExpr {
+                    left,
+                    op: Operator::Eq,
+                    right,
+                }) = expr
+                else {
+                    return None;
+                };
+                // This is a no-op filter expression
+                if left == right {
+                    return None;
+                }
+
+                match (left.as_ref(), right.as_ref()) {
+                    (Expr::Column(_), Expr::Column(_)) => None,
+                    (Expr::Column(c), _) | (_, Expr::Column(c)) => {
+                        // Skip unresolvable columns instead of panicking
+                        schema.index_of_column(c).ok()
+                    }
+                    _ => None,
+                }
+            })
+            .collect();
+
+        // If we have a functional dependence that is a subset of our predicate,
+        // this filter is scalar
+        for key in unique_keys {
+            if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) {
+                return true;
+            }
+        }
+        false
+    }
 }
 
 /// Window its input based on a set of window spec and window function (e.g. SUM or RANK)
@@ -2576,12 +2650,14 @@ pub struct Unnest {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::builder::LogicalTableSource;
     use crate::logical_plan::table_scan;
     use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::tree_node::TreeNodeVisitor;
-    use datafusion_common::{not_impl_err, DFSchema, TableReference};
+    use datafusion_common::{not_impl_err, Constraint, DFSchema, TableReference};
     use std::collections::HashMap;
+    use std::sync::Arc;
 
     fn employee_schema() -> Schema {
         Schema::new(vec![
@@ -3076,4 +3152,60 @@ digraph {
             .unwrap()
             .is_nullable());
     }
+    #[test]
+    fn test_filter_is_scalar() {
+        // without a unique constraint, an equality filter is not scalar
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let source = Arc::new(LogicalTableSource::new(schema));
+        let schema = Arc::new(
+            DFSchema::try_from_qualified_schema(
+                TableReference::bare("tab"),
+                &source.schema(),
+            )
+            .unwrap(),
+        );
+        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
+            table_name: TableReference::bare("tab"),
+            source: source.clone(),
+            projection: None,
+            projected_schema: schema.clone(),
+            filters: vec![],
+            fetch: None,
+        }));
+        let col = schema.field(0).qualified_column();
+
+        let filter = Filter::try_new(
+            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
+            scan,
+        )
+        .unwrap();
+        assert!(!filter.is_scalar());
+        let unique_schema =
+            Arc::new(schema.as_ref().clone().with_functional_dependencies(
+                FunctionalDependencies::new_from_constraints(
+                    Some(&Constraints::new_unverified(vec![Constraint::Unique(
+                        vec![0],
+                    )])),
+                    1,
+                ),
+            ));
+        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
+            table_name: TableReference::bare("tab"),
+            source,
+            projection: None,
+            projected_schema: unique_schema.clone(),
+            filters: vec![],
+            fetch: None,
+        }));
+        let col = schema.field(0).qualified_column();
+
+        let filter = Filter::try_new(
+            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
+            scan,
+        )
+        .unwrap();
+        assert!(filter.is_scalar());
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt
index 430e676fa477..3e0fcb7aa96e 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
 (44, 'x', 3),
 (55, 'w', 3);
 
+statement ok
+CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES
+(11, 'e', 3),
+(22, 'f', 1),
+(44, 'g', 3),
+(55, 'h', 3);
+
 statement ok
 CREATE EXTERNAL TABLE IF NOT EXISTS customer (
         c_custkey BIGINT,
@@ -419,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2
 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
 SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1
 
+#non_aggregated_correlated_scalar_subquery_unique
+query II rowsort
+SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1
+----
+11 3
+22 1
+33 NULL
+44 3
+
+
+#non_aggregated_correlated_scalar_subquery
 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
 SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1
 