Skip to content

Commit

Permalink
fix issue where CTE could not be referenced more than 1 time
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewgapp committed Jan 9, 2024
1 parent 3096c1d commit d7721f1
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 32 deletions.
78 changes: 54 additions & 24 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

use std::collections::HashMap;
use std::fmt::Write;
use std::sync::Arc;
use std::sync::atomic::AtomicI32;
use std::sync::{Arc, OnceLock};

use crate::datasource::file_format::arrow::ArrowFormat;
use crate::datasource::file_format::avro::AvroFormat;
Expand Down Expand Up @@ -444,11 +445,13 @@ impl PhysicalPlanner for DefaultPhysicalPlanner {
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
reset_recursive_cte_physical_plan_branch_number();

match self.handle_explain(logical_plan, session_state).await? {
Some(plan) => Ok(plan),
None => {
let plan = self
.create_initial_plan(logical_plan, session_state)
.create_initial_plan(logical_plan, session_state, None)
.await?;
self.optimize_internal(plan, session_state, |_, _| {})
}
Expand Down Expand Up @@ -479,6 +482,23 @@ impl PhysicalPlanner for DefaultPhysicalPlanner {
}
}

// atomic global incrmenter

static RECURSIVE_CTE_PHYSICAL_PLAN_BRANCH: OnceLock<AtomicI32> = OnceLock::new();

fn new_recursive_cte_physical_plan_branch_number() -> u32 {
let counter = RECURSIVE_CTE_PHYSICAL_PLAN_BRANCH
.get_or_init(|| AtomicI32::new(0))
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
counter as u32
}

fn reset_recursive_cte_physical_plan_branch_number() {
RECURSIVE_CTE_PHYSICAL_PLAN_BRANCH
.get_or_init(|| AtomicI32::new(0))
.store(0, std::sync::atomic::Ordering::SeqCst);
}

impl DefaultPhysicalPlanner {
/// Create a physical planner that uses `extension_planners` to
/// plan user-defined logical nodes [`LogicalPlan::Extension`].
Expand All @@ -499,6 +519,7 @@ impl DefaultPhysicalPlanner {
&'a self,
logical_plans: impl IntoIterator<Item = &'a LogicalPlan> + Send + 'a,
session_state: &'a SessionState,
ctx: Option<&'a String>,
) -> BoxFuture<'a, Result<Vec<Arc<dyn ExecutionPlan>>>> {
async move {
// First build futures with as little references as possible, then performing some stream magic.
Expand All @@ -511,7 +532,7 @@ impl DefaultPhysicalPlanner {
.into_iter()
.enumerate()
.map(|(idx, lp)| async move {
let plan = self.create_initial_plan(lp, session_state).await?;
let plan = self.create_initial_plan(lp, session_state, ctx).await?;
Ok((idx, plan)) as Result<_>
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -540,6 +561,7 @@ impl DefaultPhysicalPlanner {
&'a self,
logical_plan: &'a LogicalPlan,
session_state: &'a SessionState,
ctx: Option<&'a String>,
) -> BoxFuture<'a, Result<Arc<dyn ExecutionPlan>>> {
async move {
let exec_plan: Result<Arc<dyn ExecutionPlan>> = match logical_plan {
Expand All @@ -565,7 +587,7 @@ impl DefaultPhysicalPlanner {
single_file_output,
copy_options,
}) => {
let input_exec = self.create_initial_plan(input, session_state).await?;
let input_exec = self.create_initial_plan(input, session_state, ctx).await?;

// TODO: make this behavior configurable via options (should copy to create path/file as needed?)
// TODO: add additional configurable options for if existing files should be overwritten or
Expand Down Expand Up @@ -618,7 +640,7 @@ impl DefaultPhysicalPlanner {
let name = table_name.table();
let schema = session_state.schema_for_ref(table_name)?;
if let Some(provider) = schema.table(name).await {
let input_exec = self.create_initial_plan(input, session_state).await?;
let input_exec = self.create_initial_plan(input, session_state, ctx).await?;
provider.insert_into(session_state, input_exec, false).await
} else {
return exec_err!(
Expand All @@ -635,7 +657,7 @@ impl DefaultPhysicalPlanner {
let name = table_name.table();
let schema = session_state.schema_for_ref(table_name)?;
if let Some(provider) = schema.table(name).await {
let input_exec = self.create_initial_plan(input, session_state).await?;
let input_exec = self.create_initial_plan(input, session_state, ctx).await?;
provider.insert_into(session_state, input_exec, true).await
} else {
return exec_err!(
Expand Down Expand Up @@ -676,7 +698,7 @@ impl DefaultPhysicalPlanner {
);
}

let input_exec = self.create_initial_plan(input, session_state).await?;
let input_exec = self.create_initial_plan(input, session_state, ctx).await?;

// at this moment we are guaranteed by the logical planner
// to have all the window_expr to have equal sort key
Expand Down Expand Up @@ -772,7 +794,7 @@ impl DefaultPhysicalPlanner {
..
}) => {
// Initially need to perform the aggregate and then merge the partitions
let input_exec = self.create_initial_plan(input, session_state).await?;
let input_exec = self.create_initial_plan(input, session_state, ctx).await?;
let physical_input_schema = input_exec.schema();
let logical_input_schema = input.as_ref().schema();

Expand Down Expand Up @@ -849,7 +871,7 @@ impl DefaultPhysicalPlanner {
)?))
}
LogicalPlan::Projection(Projection { input, expr, .. }) => {
let input_exec = self.create_initial_plan(input, session_state).await?;
let input_exec = self.create_initial_plan(input, session_state, ctx).await?;
let input_schema = input.as_ref().schema();

let physical_exprs = expr
Expand Down Expand Up @@ -901,7 +923,7 @@ impl DefaultPhysicalPlanner {
)?))
}
LogicalPlan::Filter(filter) => {
let physical_input = self.create_initial_plan(&filter.input, session_state).await?;
let physical_input = self.create_initial_plan(&filter.input, session_state, ctx).await?;
let input_schema = physical_input.as_ref().schema();
let input_dfschema = filter.input.schema();

Expand All @@ -914,7 +936,7 @@ impl DefaultPhysicalPlanner {
Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?))
}
LogicalPlan::Union(Union { inputs, schema }) => {
let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?;
let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state, ctx).await?;

if schema.fields().len() < physical_plans[0].schema().fields().len() {
// `schema` could be a subset of the child schema. For example
Expand All @@ -929,7 +951,7 @@ impl DefaultPhysicalPlanner {
input,
partitioning_scheme,
}) => {
let physical_input = self.create_initial_plan(input, session_state).await?;
let physical_input = self.create_initial_plan(input, session_state, ctx).await?;
let input_schema = physical_input.schema();
let input_dfschema = input.as_ref().schema();
let physical_partitioning = match partitioning_scheme {
Expand Down Expand Up @@ -960,7 +982,7 @@ impl DefaultPhysicalPlanner {
)?))
}
LogicalPlan::Sort(Sort { expr, input, fetch, .. }) => {
let physical_input = self.create_initial_plan(input, session_state).await?;
let physical_input = self.create_initial_plan(input, session_state, ctx).await?;
let input_schema = physical_input.as_ref().schema();
let input_dfschema = input.as_ref().schema();
let sort_expr = expr
Expand Down Expand Up @@ -1051,12 +1073,12 @@ impl DefaultPhysicalPlanner {
};

return self
.create_initial_plan(&join_plan, session_state)
.create_initial_plan(&join_plan, session_state, ctx)
.await;
}

// All equi-join keys are columns now, create physical join plan
let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state).await?;
let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state, ctx).await?;
let [physical_left, physical_right]: [Arc<dyn ExecutionPlan>; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?;
let left_df_schema = left.schema();
let right_df_schema = right.schema();
Expand Down Expand Up @@ -1191,7 +1213,7 @@ impl DefaultPhysicalPlanner {
}
}
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state).await?;
let left_right = self.create_initial_plan_multi([left.as_ref(), right.as_ref()], session_state, ctx).await?;
let [left, right]: [Arc<dyn ExecutionPlan>; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?;
Ok(Arc::new(CrossJoinExec::new(left, right)))
}
Expand All @@ -1204,10 +1226,10 @@ impl DefaultPhysicalPlanner {
SchemaRef::new(schema.as_ref().to_owned().into()),
))),
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
self.create_initial_plan(input, session_state).await
self.create_initial_plan(input, session_state, ctx).await
}
LogicalPlan::Limit(Limit { input, skip, fetch, .. }) => {
let input = self.create_initial_plan(input, session_state).await?;
let input = self.create_initial_plan(input, session_state, ctx).await?;

// GlobalLimitExec requires a single partition for input
let input = if input.output_partitioning().partition_count() == 1 {
Expand All @@ -1225,7 +1247,7 @@ impl DefaultPhysicalPlanner {
Ok(Arc::new(GlobalLimitExec::new(input, *skip, *fetch)))
}
LogicalPlan::Unnest(Unnest { input, column, schema, options }) => {
let input = self.create_initial_plan(input, session_state).await?;
let input = self.create_initial_plan(input, session_state, ctx).await?;
let column_exec = schema.index_of_column(column)
.map(|idx| Column::new(&column.name, idx))?;
let schema = SchemaRef::new(schema.as_ref().to_owned().into());
Expand Down Expand Up @@ -1278,7 +1300,7 @@ impl DefaultPhysicalPlanner {
"Unsupported logical plan: Analyze must be root of the plan"
),
LogicalPlan::Extension(e) => {
let physical_inputs = self.create_initial_plan_multi(e.node.inputs(), session_state).await?;
let physical_inputs = self.create_initial_plan_multi(e.node.inputs(), session_state, ctx).await?;

let mut maybe_plan = None;
for planner in &self.extension_planners {
Expand Down Expand Up @@ -1314,13 +1336,19 @@ impl DefaultPhysicalPlanner {
Ok(plan)
}
}
// LogicalPlan::SubqueryAlias(SubqueryAlias())
LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term, recursive_term, is_distinct }) => {
let static_term = self.create_initial_plan(static_term, session_state).await?;
let recursive_term = self.create_initial_plan(recursive_term, session_state).await?;
let name = format!("{}-{}", name, new_recursive_cte_physical_plan_branch_number());

let ctx = Some(&name);

let static_term = self.create_initial_plan(static_term, session_state, ctx).await?;
let recursive_term = self.create_initial_plan(recursive_term, session_state, ctx).await?;

Ok(Arc::new(RecursiveQueryExec::new(name.clone(), static_term, recursive_term, *is_distinct)))
}
LogicalPlan::NamedRelation(NamedRelation {name, schema}) => {
LogicalPlan::NamedRelation(NamedRelation {schema, ..}) => {
let name = ctx.expect("NamedRelation must have a context that contains the recursive query's branch name");
// Named relations is how we represent access to any sort of dynamic data provider. They
// differ from tables in the sense that they can start existing dynamically during the
// execution of a query and then disappear before it even finishes.
Expand Down Expand Up @@ -1895,6 +1923,8 @@ impl DefaultPhysicalPlanner {
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
reset_recursive_cte_physical_plan_branch_number();

if let LogicalPlan::Explain(e) = logical_plan {
use PlanType::*;
let mut stringified_plans = vec![];
Expand All @@ -1910,7 +1940,7 @@ impl DefaultPhysicalPlanner {

if !config.logical_plan_only && e.logical_optimization_succeeded {
match self
.create_initial_plan(e.plan.as_ref(), session_state)
.create_initial_plan(e.plan.as_ref(), session_state, None)
.await
{
Ok(input) => {
Expand Down
18 changes: 10 additions & 8 deletions datafusion/sql/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use datafusion_common::{
plan_err, sql_err, Constraints, DFSchema, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder,
logical_plan, CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan,
LogicalPlanBuilder,
};
use sqlparser::ast::{
Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator,
Expand Down Expand Up @@ -133,10 +134,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
static_metadata,
)?;

let name = cte_name.clone();

// Step 2.2: Create a temporary relation logical plan that will be used
// as the input to the recursive term
let named_relation = LogicalPlanBuilder::named_relation(
cte_name.as_str(),
&name,
Arc::new(named_relation_schema),
)
.build()?;
Expand All @@ -157,14 +160,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

// ---------- Step 4: Create the final plan ------------------
// Step 4.1: Compile the final plan
let final_plan = LogicalPlanBuilder::from(static_plan)
.to_recursive_query(
cte_name.clone(),
recursive_plan,
distinct,
)?
let logical_plan = LogicalPlanBuilder::from(static_plan)
.to_recursive_query(name, recursive_plan, distinct)?
.build()?;

let final_plan =
self.apply_table_alias(logical_plan, cte.alias)?;

// Step 4.2: Remove the temporary relation from the planning context and replace it
// with the final plan.
planner_context.insert_cte(cte_name.clone(), final_plan);
Expand Down
80 changes: 80 additions & 0 deletions datafusion/sqllogictest/test_files/cte.slt
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,83 @@ WITH RECURSIVE nodes AS (
SELECT sum(id) FROM nodes
----
55

# setup
statement ok
CREATE TABLE t(a BIGINT) AS VALUES(1),(2),(3);

# referencing CTE multiple times does not error
query II rowsort
WITH RECURSIVE my_cte AS (
SELECT a from t
UNION ALL
SELECT a+2 as a
FROM my_cte
WHERE a<5
)
SELECT * FROM my_cte t1, my_cte
----
1 1
1 2
1 3
1 3
1 4
1 5
1 5
1 6
2 1
2 2
2 3
2 3
2 4
2 5
2 5
2 6
3 1
3 1
3 2
3 2
3 3
3 3
3 3
3 3
3 4
3 4
3 5
3 5
3 5
3 5
3 6
3 6
4 1
4 2
4 3
4 3
4 4
4 5
4 5
4 6
5 1
5 1
5 2
5 2
5 3
5 3
5 3
5 3
5 4
5 4
5 5
5 5
5 5
5 5
5 6
5 6
6 1
6 2
6 3
6 3
6 4
6 5
6 5
6 6

0 comments on commit d7721f1

Please sign in to comment.