Skip to content

Commit

Permalink
[FEAT] Support for correlated subqueries in SQL (not yet executable) (#…
Browse files Browse the repository at this point in the history
…3304)

This PR adds support for converting SQL queries with correlated
subqueries into LogicalPlans. It does not add the ability to execute
queries with correlated subqueries, but if I am correct, this is the
last large piece of support we need on the SQL side for TPC-H questions,
and most of the remaining work is plan rewriting, optimization, and
translation.

I believe that, with the new `alias_map` value in `SQLPlanner`, we can
actually simplify a lot of the logic in `plan_aggregate_query` and
`plan_non_agg_query`, but I will not attempt to do that in this PR.

Relevant for TPC-H questions 4, 17, 20, 21, 22.

Todo:
- [x] tests
  • Loading branch information
kevinzwang authored Nov 19, 2024
1 parent dab006f commit 25304eb
Show file tree
Hide file tree
Showing 11 changed files with 332 additions and 160 deletions.
40 changes: 35 additions & 5 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,15 @@ pub enum Expr {
#[display("{_0}")]
ScalarFunction(ScalarFunction),

#[display("{_0}")]
#[display("subquery {_0}")]
Subquery(Subquery),
#[display("{_0}, {_1}")]
#[display("{_0} in {_1}")]
InSubquery(ExprRef, Subquery),
#[display("{_0}")]
#[display("exists {_0}")]
Exists(Subquery),

#[display("{_0}")]
OuterReferenceColumn(OuterReferenceColumn),
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Hash, Eq)]
Expand All @@ -164,6 +167,20 @@ pub struct ApproxPercentileParams {
pub force_list_output: bool,
}

/// Reference to a qualified field in a parent query, used for correlated subqueries.
///
/// This is the payload of `Expr::OuterReferenceColumn`: it names a column that is
/// resolved against an enclosing query rather than against the subquery's own input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Hash, Eq)]
pub struct OuterReferenceColumn {
/// The referenced column in the parent query (its name and data type).
pub field: Field,
/// The parent query that the column refers to, with depth=1 denoting the direct parent.
/// NOTE(review): presumably depth=2 is the grandparent query, and so on — confirm with the planner.
pub depth: u64,
}

impl Display for OuterReferenceColumn {
    /// Renders the reference as `outer_col(<field name>, <depth>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so that a newly added field forces this impl to be revisited.
        let Self { field, depth } = self;
        write!(f, "outer_col({}, {})", field.name, depth)
    }
}

#[derive(Display, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum AggExpr {
#[display("count({_0}, {_1})")]
Expand Down Expand Up @@ -717,6 +734,11 @@ impl Expr {
Self::Subquery(..) | Self::InSubquery(..) | Self::Exists(..) => {
FieldID::new("__subquery__")
} // todo: better/unique id
Self::OuterReferenceColumn(c) => {
let name = &c.field.name;
let depth = c.depth;
FieldID::new(format!("outer_col({name}, {depth})"))
}
}
}

Expand All @@ -727,6 +749,7 @@ impl Expr {
Self::Literal(..) => vec![],
Self::Subquery(..) => vec![],
Self::Exists(..) => vec![],
Self::OuterReferenceColumn(..) => vec![],

// One child.
Self::Not(expr)
Expand Down Expand Up @@ -763,7 +786,11 @@ impl Expr {
pub fn with_new_children(&self, children: Vec<ExprRef>) -> Self {
match self {
// no children
Self::Column(..) | Self::Literal(..) | Self::Subquery(..) | Self::Exists(..) => {
Self::Column(..)
| Self::Literal(..)
| Self::Subquery(..)
| Self::Exists(..)
| Self::OuterReferenceColumn(..) => {
assert!(children.is_empty(), "Should have no children");
self.clone()
}
Expand Down Expand Up @@ -1027,6 +1054,7 @@ impl Expr {
}
Self::InSubquery(expr, _) => Ok(Field::new(expr.name(), DataType::Boolean)),
Self::Exists(_) => Ok(Field::new("exists", DataType::Boolean)),
Self::OuterReferenceColumn(c) => Ok(c.field.clone()),
}
}

Expand Down Expand Up @@ -1060,6 +1088,7 @@ impl Expr {
Self::Subquery(subquery) => subquery.name(),
Self::InSubquery(expr, _) => expr.name(),
Self::Exists(subquery) => subquery.name(),
Self::OuterReferenceColumn(c) => &c.field.name,
}
}

Expand Down Expand Up @@ -1135,7 +1164,8 @@ impl Expr {
| Expr::ScalarFunction { .. }
| Expr::Subquery(..)
| Expr::InSubquery(..)
| Expr::Exists(..) => Err(io::Error::new(
| Expr::Exists(..)
| Expr::OuterReferenceColumn(..) => Err(io::Error::new(
io::ErrorKind::Other,
"Unsupported expression for SQL translation",
)),
Expand Down
3 changes: 2 additions & 1 deletion src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ mod treenode;
pub use common_treenode;
pub use expr::{
binary_op, col, has_agg, has_stateful_udf, is_partition_compatible, AggExpr,
ApproxPercentileParams, Expr, ExprRef, Operator, SketchType, Subquery, SubqueryPlan,
ApproxPercentileParams, Expr, ExprRef, Operator, OuterReferenceColumn, SketchType, Subquery,
SubqueryPlan,
};
pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue};
#[cfg(feature = "python")]
Expand Down
2 changes: 1 addition & 1 deletion src/daft-dsl/src/optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub fn requires_computation(e: &Expr) -> bool {
// Returns whether or not this expression runs any computation on the underlying data
match e {
Expr::Alias(child, _) => requires_computation(child),
Expr::Column(..) | Expr::Literal(_) => false,
Expr::Column(..) | Expr::Literal(_) | Expr::OuterReferenceColumn { .. } => false,
Expr::Agg(..)
| Expr::BinaryOp { .. }
| Expr::Cast(..)
Expand Down
8 changes: 5 additions & 3 deletions src/daft-logical-plan/src/ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,11 @@ fn replace_column_with_semantic_id(
Transformed::yes(new_expr.into())
} else {
match e.as_ref() {
Expr::Column(_) | Expr::Literal(_) | Expr::Subquery(_) | Expr::Exists(_) => {
Transformed::no(e)
}
Expr::Column(_)
| Expr::Literal(_)
| Expr::Subquery(_)
| Expr::Exists(_)
| Expr::OuterReferenceColumn { .. } => Transformed::no(e),
Expr::Agg(agg_expr) => replace_column_with_semantic_id_aggexpr(
agg_expr.clone(),
subexprs_to_replace,
Expand Down
4 changes: 2 additions & 2 deletions src/daft-logical-plan/src/partitioning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ fn translate_clustering_spec_expr(

Ok(expr.in_subquery(subquery.clone()))
}
// Cannot have agg exprs in clustering specs.
Expr::Agg(_) => Err(()),
// Cannot have agg exprs or references to other tables in clustering specs.
Expr::Agg(_) | Expr::OuterReferenceColumn { .. } => Err(()),
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/daft-schema/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ impl Schema {
}
}

/// Returns `true` if this schema contains a field called `name`.
///
/// Unlike `get_index`, a missing field is reported as `false` rather than
/// as an error.
pub fn has_field(&self, name: &str) -> bool {
    self.fields.get_index_of(name).is_some()
}

pub fn get_index(&self, name: &str) -> DaftResult<usize> {
match self.fields.get_index_of(name) {
None => Err(DaftError::FieldNotFound(format!(
Expand Down
2 changes: 1 addition & 1 deletion src/daft-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ impl Default for SQLFunctions {
}
}

impl SQLPlanner {
impl<'a> SQLPlanner<'a> {
pub(crate) fn plan_function(&self, func: &Function) -> SQLPlannerResult<ExprRef> {
// assert using only supported features
check_features(func)?;
Expand Down
65 changes: 63 additions & 2 deletions src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![feature(let_chains)]

pub mod catalog;
pub mod error;
pub mod functions;
Expand Down Expand Up @@ -25,7 +27,7 @@ mod tests {

use catalog::SQLCatalog;
use daft_core::prelude::*;
use daft_dsl::{col, lit};
use daft_dsl::{col, lit, Expr, OuterReferenceColumn, Subquery};
use daft_logical_plan::{
logical_plan::Source, source_info::PlaceHolderInfo, ClusteringSpec, LogicalPlan,
LogicalPlanBuilder, LogicalPlanRef, SourceInfo,
Expand Down Expand Up @@ -106,7 +108,7 @@ mod tests {
}

#[fixture]
fn planner() -> SQLPlanner {
fn planner() -> SQLPlanner<'static> {
let mut catalog = SQLCatalog::new();

catalog.register_table("tbl1", tbl_1());
Expand Down Expand Up @@ -395,4 +397,63 @@ mod tests {

Ok(())
}

/// Checks that queries which alias a select-list column (`i32 as a`) can be
/// planned, whether later clauses refer to the column by its original name
/// or by its alias.
///
/// Only successful planning is asserted; the shape of the resulting plan is
/// not inspected.
#[rstest]
#[case::with_where("select i32 as a from tbl1 where i32 > 0")]
#[case::with_where_aliased("select i32 as a from tbl1 where a > 0")]
#[case::with_groupby("select i32 as a from tbl1 group by i32")]
#[case::with_groupby_aliased("select i32 as a from tbl1 group by a")]
#[case::with_orderby("select i32 as a from tbl1 order by i32")]
#[case::with_orderby_aliased("select i32 as a from tbl1 order by a")]
#[case::with_many("select i32 as a from tbl1 where i32 > 0 group by i32 order by i32")]
#[case::with_many_aliased("select i32 as a from tbl1 where a > 0 group by a order by a")]
#[case::second_select("select i32 as a, a + 1 from tbl1")]
fn test_compiles_select_alias(
    mut planner: SQLPlanner,
    #[case] query: &str,
) -> SQLPlannerResult<()> {
    // Panic (rather than propagate with `?`) so that the failing query text
    // is part of the test output.
    if let Err(e) = planner.plan_sql(query) {
        panic!("query: {query}\nerror: {e:?}");
    }

    Ok(())
}

/// Plans a query whose scalar subquery refers to a column of the enclosing
/// query (`tbl1.i32`, written both unqualified and qualified) and compares
/// the result against a hand-built logical plan in which that column appears
/// as an `Expr::OuterReferenceColumn` with depth 1.
#[rstest]
#[case::basic("select utf8 from tbl1 where i64 > (select max(id) from tbl2 where id = i32)")]
#[case::compound(
    "select utf8 from tbl1 where i64 > (select max(id) from tbl2 where id = tbl1.i32)"
)]
fn test_correlated_subquery(
    mut planner: SQLPlanner,
    #[case] query: &str,
    tbl_1: LogicalPlanRef,
    tbl_2: LogicalPlanRef,
) -> SQLPlannerResult<()> {
    let actual = planner.plan_sql(query)?;

    // `i32` is not a column of tbl2, so inside the subquery it is expected to
    // resolve to the direct parent query (depth 1).
    let parent_field = Field::new("i32", DataType::Int32);
    let correlated_col = Arc::new(Expr::OuterReferenceColumn(OuterReferenceColumn {
        field: parent_field,
        depth: 1,
    }));

    // Inner query: select max(id) from tbl2 where id = <outer i32>
    let inner_plan = LogicalPlanBuilder::new(tbl_2, None)
        .filter(col("id").eq(correlated_col))?
        .aggregate(vec![col("id").max()], vec![])?
        .select(vec![col("id")])?
        .build();
    let scalar_subquery = Arc::new(Expr::Subquery(Subquery { plan: inner_plan }));

    // Outer query: select utf8 from tbl1 where i64 > (<inner query>)
    let expected = LogicalPlanBuilder::new(tbl_1, None)
        .filter(col("i64").gt(scalar_subquery))?
        .select(vec![col("utf8")])?
        .build();

    assert_eq!(actual, expected);

    Ok(())
}
}
Loading

0 comments on commit 25304eb

Please sign in to comment.