diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5663e736dbd8..db5913503aaa 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -360,9 +360,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5" +checksum = "a116f46a969224200a0a97f29cfd4c50e7534e4b4826bd23ea2c3c533039c82c" dependencies = [ "bzip2", "flate2", @@ -733,9 +733,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" [[package]] name = "blake2" @@ -1125,7 +1125,7 @@ dependencies = [ "half", "hashbrown 0.14.3", "indexmap 2.1.0", - "itertools 0.12.0", + "itertools", "log", "num-traits", "num_cpus", @@ -1235,7 +1235,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "hashbrown 0.14.3", - "itertools 0.12.0", + "itertools", "log", "regex-syntax", ] @@ -1260,7 +1260,7 @@ dependencies = [ "hashbrown 0.14.3", "hex", "indexmap 2.1.0", - "itertools 0.12.0", + "itertools", "log", "md-5", "paste", @@ -1291,7 +1291,7 @@ dependencies = [ "half", "hashbrown 0.14.3", "indexmap 2.1.0", - "itertools 0.12.0", + "itertools", "log", "once_cell", "parking_lot", @@ -1652,9 +1652,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.23" +version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b553656127a00601c8ae5590fcfdc118e4083a7924b6cf4ffc1ea4b99dc429d7" +checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" dependencies = [ "bytes", "fnv", @@ -1722,9 +1722,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" [[package]] name = "hex" @@ -1908,15 +1908,6 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.12.0" @@ -2072,16 +2063,16 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "libc", "redox_syscall", ] [[package]] name = "linux-raw-sys" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" @@ -2279,7 +2270,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.3", + "hermit-abi 0.3.4", "libc", ] @@ -2305,7 +2296,7 @@ dependencies = [ "futures", "humantime", "hyper", - "itertools 0.12.0", + "itertools", "parking_lot", "percent-encoding", "quick-xml", @@ -2516,9 +2507,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" +checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" [[package]] name = "powerfmt" @@ -2534,14 +2525,13 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "predicates" -version = "3.0.4" +version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dfc28575c2e3f19cb3c73b93af36460ae898d426eba6fc15b9bd2a5220758a0" +checksum = "68b87bfd4605926cdfefc1c3b5f8fe560e3feca9d5552cf68c466d3d8236c7e8" dependencies = [ "anstyle", "difflib", "float-cmp", - "itertools 0.11.0", "normalize-line-endings", "predicates-core", "regex", @@ -2836,11 +2826,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.29" +version = "0.38.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a1a81a2478639a14e68937903356dbac62cf52171148924f754bb8a8cd7a96c" +checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "errno", "libc", "linux-raw-sys", @@ -3102,9 +3092,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.2" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" +checksum = "2593d31f82ead8df961d8bd23a64c2ccf2eb5dd34b0a34bfb4dd54011c72009e" [[package]] name = "snafu" @@ -3563,9 +3553,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" [[package]] name = "unicode-ident" diff --git a/datafusion/core/src/datasource/cte_worktable.rs b/datafusion/core/src/datasource/cte_worktable.rs new file mode 100644 index 000000000000..de13e73e003b --- /dev/null +++ b/datafusion/core/src/datasource/cte_worktable.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! CteWorkTable implementation used for recursive queries + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use async_trait::async_trait; +use datafusion_common::not_impl_err; + +use crate::{ + error::Result, + logical_expr::{Expr, LogicalPlan, TableProviderFilterPushDown}, + physical_plan::ExecutionPlan, +}; + +use datafusion_common::DataFusionError; + +use crate::datasource::{TableProvider, TableType}; +use crate::execution::context::SessionState; + +/// The temporary working table where the previous iteration of a recursive query is stored +/// Naming is based on PostgreSQL's implementation. +/// See here for more details: www.postgresql.org/docs/11/queries-with.html#id-1.5.6.12.5.4 +pub struct CteWorkTable { + /// The name of the CTE work table + // WIP, see https://github.com/apache/arrow-datafusion/issues/462 + #[allow(dead_code)] + name: String, + /// This schema must be shared across both the static and recursive terms of a recursive query + table_schema: SchemaRef, +} + +impl CteWorkTable { + /// construct a new CteWorkTable with the given name and schema + /// This schema must match the schema of the recursive term of the query + /// Since the scan method will contain an physical plan that assumes this schema + pub fn new(name: &str, table_schema: SchemaRef) -> Self { + Self { + name: name.to_owned(), + table_schema, + } + } +} + +#[async_trait] +impl TableProvider for CteWorkTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_logical_plan(&self) -> Option<&LogicalPlan> { + None + } + + fn schema(&self) -> SchemaRef { + self.table_schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Temporary + } + + async fn scan( + &self, + _state: &SessionState, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + not_impl_err!("scan not implemented for CteWorkTable yet") + } + + fn supports_filter_pushdown( + &self, + _filter: &Expr, + ) -> Result { + // TODO: should we support filter pushdown? + Ok(TableProviderFilterPushDown::Unsupported) + } +} diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 2e516cc36a01..8f20da183a93 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -20,6 +20,7 @@ //! [`ListingTable`]: crate::datasource::listing::ListingTable pub mod avro_to_arrow; +pub mod cte_worktable; pub mod default_table_source; pub mod empty; pub mod file_format; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 1e378541b624..9b623d7a51ec 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -26,6 +26,7 @@ mod parquet; use crate::{ catalog::{CatalogList, MemoryCatalogList}, datasource::{ + cte_worktable::CteWorkTable, function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable}, provider::TableProviderFactory, @@ -1899,6 +1900,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { Ok(provider_as_source(provider)) } + /// Create a new CTE work table for a recursive CTE logical plan + /// This table will be used in conjunction with a Worktable physical plan + /// to read and write each iteration of a recursive CTE + fn create_cte_work_table( + &self, + name: &str, + schema: SchemaRef, + ) -> Result> { + let table = Arc::new(CteWorkTable::new(name, schema)); + Ok(provider_as_source(table)) + } + fn get_function_meta(&self, name: &str) -> Option> { self.state.scalar_functions().get(name).cloned() } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 98390ac271d0..bc448fe06fcf 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -87,8 +87,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, RecursiveQuery, ScalarFunctionDefinition, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -1290,6 +1290,9 @@ impl DefaultPhysicalPlanner { Ok(plan) } } + LogicalPlan::RecursiveQuery(RecursiveQuery { name: _, static_term: _, recursive_term: _, is_distinct: _,.. }) => { + not_impl_err!("Physical counterpart of RecursiveQuery is not implemented yet") + } }; exec_plan }.boxed() diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 847fbbbf61c7..eb5e5bd42634 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -55,6 +55,8 @@ use datafusion_common::{ ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; +use super::plan::RecursiveQuery; + /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -121,6 +123,28 @@ impl LogicalPlanBuilder { })) } + /// Convert a regular plan into a recursive query. + /// `is_distinct` indicates whether the recursive term should be de-duplicated (`UNION`) after each iteration or not (`UNION ALL`). + pub fn to_recursive_query( + &self, + name: String, + recursive_term: LogicalPlan, + is_distinct: bool, + ) -> Result { + // TODO: we need to do a bunch of validation here. Maybe more. + if is_distinct { + return Err(DataFusionError::NotImplemented( + "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported".to_string(), + )); + } + Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term: Arc::new(self.plan.clone()), + recursive_term: Arc::new(recursive_term), + is_distinct, + }))) + } + /// Create a values list based relation, and the schema is inferred from data, consuming /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index bc722dd69ace..f6e6000897a5 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -36,8 +36,8 @@ pub use plan::{ projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, - Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, - ToStringifiedPlan, Union, Unnest, Values, Window, + RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, + TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 93a38fb40df5..5ab8a9c99cd0 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -154,6 +154,8 @@ pub enum LogicalPlan { /// Unnest a column that contains a nested list type such as an /// ARRAY. This is used to implement SQL `UNNEST` Unnest(Unnest), + /// A variadic query (e.g. "Recursive CTEs") + RecursiveQuery(RecursiveQuery), } impl LogicalPlan { @@ -191,6 +193,10 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // we take the schema of the static term as the schema of the entire recursive query + static_term.schema() + } } } @@ -243,6 +249,10 @@ impl LogicalPlan { | LogicalPlan::TableScan(_) => { vec![self.schema()] } + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // return only the schema of the static term + static_term.all_schemas() + } // return children schemas LogicalPlan::Limit(_) | LogicalPlan::Subquery(_) @@ -384,6 +394,7 @@ impl LogicalPlan { .try_for_each(f), // plans without expressions LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Limit(_) @@ -430,6 +441,11 @@ impl LogicalPlan { LogicalPlan::Ddl(ddl) => ddl.inputs(), LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], LogicalPlan::Prepare(Prepare { input, .. }) => vec![input], + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. + }) => vec![static_term, recursive_term], // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::Statement { .. } @@ -510,6 +526,9 @@ impl LogicalPlan { cross.left.head_output_expr() } } + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + static_term.head_output_expr() + } LogicalPlan::Union(union) => Ok(Some(Expr::Column( union.schema.fields()[0].qualified_column(), ))), @@ -835,6 +854,14 @@ impl LogicalPlan { }; Ok(LogicalPlan::Distinct(distinct)) } + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, is_distinct, .. + }) => Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + name: name.clone(), + static_term: Arc::new(inputs[0].clone()), + recursive_term: Arc::new(inputs[1].clone()), + is_distinct: *is_distinct, + })), LogicalPlan::Analyze(a) => { assert!(expr.is_empty()); assert_eq!(inputs.len(), 1); @@ -1073,6 +1100,7 @@ impl LogicalPlan { }), LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch, LogicalPlan::EmptyRelation(_) => Some(0), + LogicalPlan::RecursiveQuery(_) => None, LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, @@ -1408,6 +1436,11 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.0 { LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"), + LogicalPlan::RecursiveQuery(RecursiveQuery { + is_distinct, .. + }) => { + write!(f, "RecursiveQuery: is_distinct={}", is_distinct) + } LogicalPlan::Values(Values { ref values, .. }) => { let str_values: Vec<_> = values .iter() @@ -1718,6 +1751,42 @@ pub struct EmptyRelation { pub schema: DFSchemaRef, } +/// A variadic query operation, Recursive CTE. +/// +/// # Recursive Query Evaluation +/// +/// From the [Postgres Docs]: +/// +/// 1. Evaluate the non-recursive term. For `UNION` (but not `UNION ALL`), +/// discard duplicate rows. Include all remaining rows in the result of the +/// recursive query, and also place them in a temporary working table. +// +/// 2. So long as the working table is not empty, repeat these steps: +/// +/// * Evaluate the recursive term, substituting the current contents of the +/// working table for the recursive self-reference. For `UNION` (but not `UNION +/// ALL`), discard duplicate rows and rows that duplicate any previous result +/// row. Include all remaining rows in the result of the recursive query, and +/// also place them in a temporary intermediate table. +/// +/// * Replace the contents of the working table with the contents of the +/// intermediate table, then empty the intermediate table. +/// +/// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct RecursiveQuery { + /// Name of the query + pub name: String, + /// The static term (initial contents of the working table) + pub static_term: Arc, + /// The recursive term (evaluated on the contents of the working table until + /// it returns an empty set) + pub recursive_term: Arc, + /// Should the output of the recursive term be deduplicated (`UNION`) or + /// not (`UNION ALL`). + pub is_distinct: bool, +} + /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index fc867df23c36..f29c7406acc9 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -365,6 +365,7 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Prepare(_) => { // apply the optimization to all inputs of the plan utils::optimize_children(self, plan, config)? diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index d9c45510972c..ab0cb0a26551 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -163,6 +163,7 @@ fn optimize_projections( .collect::>() } LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) | LogicalPlan::Extension(_) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index f10f11c1c093..d95d69780301 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1734,6 +1734,9 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), + LogicalPlan::RecursiveQuery(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for RecursiveQuery", + )), } } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index a04df5589b85..d4dd42edcd39 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -61,6 +61,19 @@ pub trait ContextProvider { not_impl_err!("Table Functions are not supported") } + /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) + /// We don't directly implement this in the logical plan's ['SqlToRel`] + /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency + /// of the sql crate (namely, the `CteWorktable`). + /// The [`ContextProvider`] provides a way to "hide" this dependency. + fn create_cte_work_table( + &self, + _name: &str, + _schema: SchemaRef, + ) -> Result> { + not_impl_err!("Recursive CTE is not implemented") + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 388377e3ee6b..af0b91ae6c7e 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use arrow::datatypes::Schema; use datafusion_common::{ not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, }; @@ -26,7 +27,8 @@ use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, + Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator, + SetQuantifier, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -52,21 +54,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let set_expr = query.body; if let Some(with) = query.with { // Process CTEs from top to bottom - // do not allow self-references - if with.recursive { - if self - .context_provider - .options() - .execution - .enable_recursive_ctes - { - return plan_err!( - "Recursive CTEs are enabled but are not yet supported" - ); - } else { - return not_impl_err!("Recursive CTEs are not supported"); - } - } + + let is_recursive = with.recursive; for cte in with.cte_tables { // A `WITH` block can't use the same name more than once @@ -76,16 +65,127 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "WITH query name {cte_name:?} specified more than once" ))); } - // create logical plan & pass backreferencing CTEs - // CTE expr don't need extend outer_query_schema - let logical_plan = - self.query_to_plan(*cte.query, &mut planner_context.clone())?; - // Each `WITH` block can change the column names in the last - // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). - let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; + if is_recursive { + if !self + .context_provider + .options() + .execution + .enable_recursive_ctes + { + return not_impl_err!("Recursive CTEs are not enabled"); + } + + match *cte.query.body { + SetExpr::SetOperation { + op: SetOperator::Union, + left, + right, + set_quantifier, + } => { + let distinct = set_quantifier != SetQuantifier::All; + + // Each recursive CTE consists from two parts in the logical plan: + // 1. A static term (the left hand side on the SQL, where the + // referencing to the same CTE is not allowed) + // + // 2. A recursive term (the right hand side, and the recursive + // part) + + // Since static term does not have any specific properties, it can + // be compiled as if it was a regular expression. This will + // allow us to infer the schema to be used in the recursive term. + + // ---------- Step 1: Compile the static term ------------------ + let static_plan = self + .set_expr_to_plan(*left, &mut planner_context.clone())?; + + // Since the recursive CTEs include a component that references a + // table with its name, like the example below: + // + // WITH RECURSIVE values(n) AS ( + // SELECT 1 as n -- static term + // UNION ALL + // SELECT n + 1 + // FROM values -- self reference + // WHERE n < 100 + // ) + // + // We need a temporary 'relation' to be referenced and used. PostgreSQL + // calls this a 'working table', but it is entirely an implementation + // detail and a 'real' table with that name might not even exist (as + // in the case of DataFusion). + // + // Since we can't simply register a table during planning stage (it is + // an execution problem), we'll use a relation object that preserves the + // schema of the input perfectly and also knows which recursive CTE it is + // bound to. + + // ---------- Step 2: Create a temporary relation ------------------ + // Step 2.1: Create a table source for the temporary relation + let work_table_source = + self.context_provider.create_cte_work_table( + &cte_name, + Arc::new(Schema::from(static_plan.schema().as_ref())), + )?; - planner_context.insert_cte(cte_name, logical_plan); + // Step 2.2: Create a temporary relation logical plan that will be used + // as the input to the recursive term + let work_table_plan = LogicalPlanBuilder::scan( + cte_name.to_string(), + work_table_source, + None, + )? + .build()?; + + let name = cte_name.clone(); + + // Step 2.3: Register the temporary relation in the planning context + // For all the self references in the variadic term, we'll replace it + // with the temporary relation we created above by temporarily registering + // it as a CTE. This temporary relation in the planning context will be + // replaced by the actual CTE plan once we're done with the planning. + planner_context.insert_cte(cte_name.clone(), work_table_plan); + + // ---------- Step 3: Compile the recursive term ------------------ + // this uses the named_relation we inserted above to resolve the + // relation. This ensures that the recursive term uses the named relation logical plan + // and thus the 'continuance' physical plan as its input and source + let recursive_plan = self + .set_expr_to_plan(*right, &mut planner_context.clone())?; + + // ---------- Step 4: Create the final plan ------------------ + // Step 4.1: Compile the final plan + let logical_plan = LogicalPlanBuilder::from(static_plan) + .to_recursive_query(name, recursive_plan, distinct)? + .build()?; + + let final_plan = + self.apply_table_alias(logical_plan, cte.alias)?; + + // Step 4.2: Remove the temporary relation from the planning context and replace it + // with the final plan. + planner_context.insert_cte(cte_name.clone(), final_plan); + } + _ => { + return Err(DataFusionError::SQL( + ParserError(format!("Unsupported CTE: {cte}")), + None, + )); + } + }; + } else { + // create logical plan & pass backreferencing CTEs + // CTE expr don't need extend outer_query_schema + let logical_plan = + self.query_to_plan(*cte.query, &mut planner_context.clone())?; + + // Each `WITH` block can change the column names in the last + // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). + let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; + + planner_context.insert_cte(cte_name, logical_plan); + } } } let plan = self.set_expr_to_plan(*(set_expr.clone()), planner_context)?; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 44da4cd4d836..c88e2d1130ed 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1394,11 +1394,46 @@ fn recursive_ctes() { select * from numbers;"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "This feature is not implemented: Recursive CTEs are not supported", + "This feature is not implemented: Recursive CTEs are not enabled", err.strip_backtrace() ); } +#[test] +fn recursive_ctes_enabled() { + let sql = " + WITH RECURSIVE numbers AS ( + select 1 as n + UNION ALL + select n + 1 FROM numbers WHERE N < 10 + ) + select * from numbers;"; + + // manually setting up test here so that we can enable recursive ctes + let mut context = MockContextProvider::default(); + context.options_mut().execution.enable_recursive_ctes = true; + + let planner = SqlToRel::new_with_options(&context, ParserOptions::default()); + let result = DFParser::parse_sql_with_dialect(sql, &GenericDialect {}); + let mut ast = result.unwrap(); + + let plan = planner + .statement_to_plan(ast.pop_front().unwrap()) + .expect("recursive cte plan creation failed"); + + assert_eq!( + format!("{plan:?}"), + "Projection: numbers.n\ + \n SubqueryAlias: numbers\ + \n RecursiveQuery: is_distinct=false\ + \n Projection: Int64(1) AS n\ + \n EmptyRelation\ + \n Projection: numbers.n + Int64(1)\ + \n Filter: numbers.n < Int64(10)\ + \n TableScan: numbers" + ); +} + #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( @@ -2692,6 +2727,12 @@ struct MockContextProvider { udafs: HashMap>, } +impl MockContextProvider { + fn options_mut(&mut self) -> &mut ConfigOptions { + &mut self.options + } +} + impl ContextProvider for MockContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { let schema = match name.table() { @@ -2801,6 +2842,14 @@ impl ContextProvider for MockContextProvider { fn options(&self) -> &ConfigOptions { &self.options } + + fn create_cte_work_table( + &self, + _name: &str, + schema: SchemaRef, + ) -> Result> { + Ok(Arc::new(EmptyTable::new(schema))) + } } #[test]