Detect when filters make subqueries scalar (apache#8312)
Co-authored-by: Andrew Lamb <[email protected]>
2 people authored and appletreeisyellow committed Dec 15, 2023
1 parent 18f8149 commit 304d7ae
Showing 3 changed files with 164 additions and 3 deletions.
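In brief: when a filter's equality predicates pin down every source column of a unique key of its input, the filter can return at most one row. The new `Filter::is_scalar` (added below) detects this, `LogicalPlan::max_rows` then reports `Some(1)`, and the analyzer can accept an otherwise-rejected correlated scalar subquery. A minimal sketch of the resulting behavior, mirroring the unit test added to plan.rs in this commit (the table name `tab` and the one-column schema are illustrative only):

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{
    Constraint, Constraints, DFSchema, FunctionalDependencies, ScalarValue, TableReference,
};
use datafusion_expr::builder::LogicalTableSource;
use datafusion_expr::{Expr, Filter, LogicalPlan, TableScan};

fn main() {
    // A one-column table whose `id` column carries a UNIQUE constraint.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let source = Arc::new(LogicalTableSource::new(schema.clone()));
    let df_schema = Arc::new(
        DFSchema::try_from_qualified_schema(TableReference::bare("tab"), &schema)
            .unwrap()
            .with_functional_dependencies(FunctionalDependencies::new_from_constraints(
                Some(&Constraints::new_unverified(vec![Constraint::Unique(vec![0])])),
                1,
            )),
    );
    let scan = Arc::new(LogicalPlan::TableScan(TableScan {
        table_name: TableReference::bare("tab"),
        source,
        projection: None,
        projected_schema: df_schema.clone(),
        filters: vec![],
        fetch: None,
    }));

    // `id = 1` pins the unique key, so the filter yields at most one row.
    let key = df_schema.field(0).qualified_column();
    let filter = Filter::try_new(
        Expr::Column(key).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
        scan,
    )
    .unwrap();
    assert_eq!(LogicalPlan::Filter(filter).max_rows(), Some(1));
}

Without the `Constraint::Unique`, `max_rows` falls back to the input's row bound, which is exactly what the first half of the new unit test asserts.
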
8 changes: 8 additions & 0 deletions datafusion/common/src/functional_dependencies.rs
@@ -413,6 +413,14 @@ impl FunctionalDependencies {
}
}

impl Deref for FunctionalDependencies {
type Target = [FunctionalDependence];

fn deref(&self) -> &Self::Target {
self.deps.as_slice()
}
}

/// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression.
pub fn aggregate_functional_dependencies(
aggr_input_schema: &DFSchema,
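A small sketch of what the new `Deref` impl buys (the construction mirrors the unit test added to plan.rs below): slice methods such as `len`, `is_empty`, and `iter` now apply to `FunctionalDependencies` directly, which is what the new `Filter::is_scalar` relies on when it iterates the dependencies.

use datafusion_common::{Constraint, Constraints, FunctionalDependencies};

fn main() {
    // Dependencies equivalent to a UNIQUE constraint on column 0 of a
    // one-column schema (same construction as the unit test in plan.rs).
    let deps = FunctionalDependencies::new_from_constraints(
        Some(&Constraints::new_unverified(vec![Constraint::Unique(vec![0])])),
        1,
    );
    // Via `Deref<Target = [FunctionalDependence]>`, slice methods work directly:
    assert!(!deps.is_empty());
    assert_eq!(deps.len(), deps.iter().count());
}
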
141 changes: 138 additions & 3 deletions datafusion/expr/src/logical_plan/plan.rs
@@ -33,6 +33,7 @@ use crate::logical_plan::{DmlStatement, Statement};
use crate::utils::{
enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs,
grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre,
split_conjunction,
};
use crate::{
build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr,
@@ -47,7 +48,7 @@ use datafusion_common::tree_node::{
};
use datafusion_common::{
aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies,
OwnedTableReference, ParamValues, Result, UnnestOptions,
};
// backwards compatibility
@@ -1032,7 +1033,13 @@ impl LogicalPlan {
pub fn max_rows(self: &LogicalPlan) -> Option<usize> {
match self {
LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(),
LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(),
LogicalPlan::Filter(filter) => {
if filter.is_scalar() {
Some(1)
} else {
filter.input.max_rows()
}
}
LogicalPlan::Window(Window { input, .. }) => input.max_rows(),
LogicalPlan::Aggregate(Aggregate {
input, group_expr, ..
@@ -1913,6 +1920,73 @@ impl Filter {

Ok(Self { predicate, input })
}

/// Is this filter guaranteed to return 0 or 1 row in a given instantiation?
///
/// This function returns `true` if its predicate is a conjunction containing
/// `col(a) = <expr>` equality terms, and the input schema has a unique
/// constraint whose columns are all covered by those terms.
///
/// For example, for the table:
/// ```sql
/// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER);
/// ```
/// `Filter(a = 2).is_scalar() == true`, whereas
/// `Filter(b = 2).is_scalar() == false` and
/// `Filter(a = 2 OR b = 2).is_scalar() == false`.
fn is_scalar(&self) -> bool {
let schema = self.input.schema();

let functional_dependencies = schema.functional_dependencies();
let unique_keys = functional_dependencies.iter().filter(|dep| {
let nullable = dep.nullable
&& dep
.source_indices
.iter()
.any(|&source| schema.field(source).is_nullable());
!nullable
&& dep.mode == Dependency::Single
&& dep.target_indices.len() == schema.fields().len()
});

let exprs = split_conjunction(&self.predicate);
let eq_pred_cols: HashSet<_> = exprs
.iter()
.filter_map(|expr| {
let Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::Eq,
right,
}) = expr
else {
return None;
};
// Skip trivially true comparisons such as `a = a`; they constrain nothing
if left == right {
return None;
}

match (left.as_ref(), right.as_ref()) {
(Expr::Column(_), Expr::Column(_)) => None,
(Expr::Column(c), _) | (_, Expr::Column(c)) => {
Some(schema.index_of_column(c).unwrap())
}
_ => None,
}
})
.collect();

// If every source column of some unique key is bound by an equality
// predicate, the filter can match at most one row
for key in unique_keys {
if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) {
return true;
}
}
false
}
}

/// Windows its input based on a set of window specs and window functions (e.g. SUM or RANK)
@@ -2554,12 +2628,16 @@ pub struct Unnest {
#[cfg(test)]
mod tests {
use super::*;
use crate::builder::LogicalTableSource;
use crate::logical_plan::table_scan;
use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::tree_node::TreeNodeVisitor;
use datafusion_common::{not_impl_err, DFSchema, ScalarValue, TableReference};
use datafusion_common::{
not_impl_err, Constraint, DFSchema, ScalarValue, TableReference,
};
use std::collections::HashMap;
use std::sync::Arc;

fn employee_schema() -> Schema {
Schema::new(vec![
@@ -3056,6 +3134,63 @@ digraph {
.is_nullable());
}

#[test]
fn test_filter_is_scalar() {
// first, a schema with no unique constraint: the filter cannot be scalar
let schema =
Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

let source = Arc::new(LogicalTableSource::new(schema));
let schema = Arc::new(
DFSchema::try_from_qualified_schema(
TableReference::bare("tab"),
&source.schema(),
)
.unwrap(),
);
let scan = Arc::new(LogicalPlan::TableScan(TableScan {
table_name: TableReference::bare("tab"),
source: source.clone(),
projection: None,
projected_schema: schema.clone(),
filters: vec![],
fetch: None,
}));
let col = schema.field(0).qualified_column();

let filter = Filter::try_new(
Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
scan,
)
.unwrap();
assert!(!filter.is_scalar());
let unique_schema =
Arc::new(schema.as_ref().clone().with_functional_dependencies(
FunctionalDependencies::new_from_constraints(
Some(&Constraints::new_unverified(vec![Constraint::Unique(
vec![0],
)])),
1,
),
));
let scan = Arc::new(LogicalPlan::TableScan(TableScan {
table_name: TableReference::bare("tab"),
source,
projection: None,
projected_schema: unique_schema.clone(),
filters: vec![],
fetch: None,
}));
let col = schema.field(0).qualified_column();

let filter = Filter::try_new(
Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
scan,
)
.unwrap();
assert!(filter.is_scalar());
}

#[test]
fn test_transform_explain() {
let schema = Schema::new(vec![
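The equality-collection step above hinges on `split_conjunction`: only top-level `AND`s are split into separate terms, so an `OR` can never contribute equality predicates toward covering a unique key. A small illustration, assuming only the public `datafusion_expr` API (the column names `a` and `b` are arbitrary):

use datafusion_expr::utils::split_conjunction;
use datafusion_expr::{col, lit};

fn main() {
    // `a = 2 AND b = 3` splits into two equality terms, each of which can
    // help cover a unique key's source columns.
    let conjunction = col("a").eq(lit(2)).and(col("b").eq(lit(3)));
    assert_eq!(split_conjunction(&conjunction).len(), 2);

    // `a = 2 OR b = 3` stays a single non-equality expression, so it can
    // never prove a unique key covered and the filter is not scalar.
    let disjunction = col("a").eq(lit(2)).or(col("b").eq(lit(3)));
    assert_eq!(split_conjunction(&disjunction).len(), 1);
}
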
18 changes: 18 additions & 0 deletions datafusion/sqllogictest/test_files/subquery.slt
@@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
(44, 'x', 3),
(55, 'w', 3);

statement ok
CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES
(11, 'e', 3),
(22, 'f', 1),
(44, 'g', 3),
(55, 'h', 3);

statement ok
CREATE EXTERNAL TABLE IF NOT EXISTS customer (
c_custkey BIGINT,
@@ -419,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2
statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1

#non_aggregated_correlated_scalar_subquery_unique
query II rowsort
SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1
----
11 3
22 1
33 NULL
44 3


#non_aggregated_correlated_scalar_subquery
statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1

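To try the new slt case outside sqllogictest, a hedged end-to-end sketch using the `datafusion` crate's `SessionContext`; the `t1` rows below are illustrative stand-ins for the table defined earlier in subquery.slt, but they reuse the same ids, so the query reproduces the output block above:

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Illustrative t1; subquery.slt defines the real one earlier in the file.
    ctx.sql(
        "CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES \
         (11, 'a', 1), (22, 'b', 2), (33, 'c', 3), (44, 'd', 4)",
    )
    .await?;
    // Same t3 as above: t3_id is a PRIMARY KEY, hence a unique key.
    ctx.sql(
        "CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES \
         (11, 'e', 3), (22, 'f', 1), (44, 'g', 3), (55, 'h', 3)",
    )
    .await?;
    // Planning succeeds only because `t3.t3_id = t1.t1_id` pins t3's unique
    // key, making the correlated subquery provably scalar.
    ctx.sql("SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) AS t3_int FROM t1")
        .await?
        .show()
        .await?;
    Ok(())
}
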
