Skip to content

Commit

Permalink
[SPARK-43199][SQL] Make InlineCTE idempotent
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR fixes `InlineCTE`'s idempotence. E.g. the following query:
```
WITH
  x(r) AS (SELECT random()),
  y(r) AS (SELECT * FROM x),
  z(r) AS (SELECT * FROM x)
SELECT * FROM z
```
currently breaks it because we take into account the reference to `x` from `y` when deciding about not inlining `x` in the first round:
```
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.InlineCTE ===
 WithCTE                                                        WithCTE
 :- CTERelationDef 0, false                                     :- CTERelationDef 0, false
 :  +- Project [rand()#218 AS r#219]                            :  +- Project [rand()#218 AS r#219]
 :     +- Project [random(2957388522017368375) AS rand()#218]   :     +- Project [random(2957388522017368375) AS rand()#218]
 :        +- OneRowRelation                                     :        +- OneRowRelation
!:- CTERelationDef 1, false                                     +- Project [r#222]
!:  +- Project [r#219 AS r#221]                                    +- Project [r#220 AS r#222]
!:     +- Project [r#219]                                             +- Project [r#220]
!:        +- CTERelationRef 0, true, [r#219]                             +- CTERelationRef 0, true, [r#220]
!:- CTERelationDef 2, false
!:  +- Project [r#220 AS r#222]
!:     +- Project [r#220]
!:        +- CTERelationRef 0, true, [r#220]
!+- Project [r#222]
!   +- CTERelationRef 2, true, [r#222]
```
But in the next round we inline `x` because `y` was removed due to lack of references:
```
Once strategy's idempotence is broken for batch Inline CTE
!WithCTE                                                        Project [r#222]
!:- CTERelationDef 0, false                                     +- Project [r#220 AS r#222]
!:  +- Project [rand()#218 AS r#219]                               +- Project [r#220]
!:     +- Project [random(2957388522017368375) AS rand()#218]         +- Project [r#225 AS r#220]
!:        +- OneRowRelation                                              +- Project [rand()#218 AS r#225]
!+- Project [r#222]                                                         +- Project [random(2957388522017368375) AS rand()#218]
!   +- Project [r#220 AS r#222]                                                +- OneRowRelation
!      +- Project [r#220]
!         +- CTERelationRef 0, true, [r#220]
```

### Why are the changes needed?
We use `InlineCTE` as an idempotent rule in the `Optimizer`, `CheckAnalysis` and `ProgressReporter`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added new UT.

Closes #40856 from peter-toth/SPARK-43199-make-inlinecte-idempotent.

Authored-by: Peter Toth <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
peter-toth authored and cloud-fan committed Apr 26, 2023
1 parent 9c19528 commit 8970415
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB

def checkAnalysis(plan: LogicalPlan): Unit = {
val inlineCTE = InlineCTE(alwaysInline = true)
val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)]
val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
inlineCTE.buildCTEMap(plan, cteMap)
cteMap.values.foreach { case (relation, refCount) =>
cteMap.values.foreach { case (relation, refCount, _) =>
// If a CTE relation is never used, it will disappear after inline. Here we explicitly check
// analysis for it, to make sure the entire query plan is valid.
if (refCount == 0) checkAnalysis0(relation.child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)]
val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
buildCTEMap(plan, cteMap)
cleanCTEMap(cteMap)
val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
val inlined = inlineCTE(plan, cteMap, notInlined)
// CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
Expand All @@ -68,50 +69,91 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
cteDef.child.exists(_.expressions.exists(_.isInstanceOf[OuterReference]))
}

/**
* Accumulates all the CTEs from a plan into a special map.
*
* @param plan The plan to collect the CTEs from
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
* ids. The value of the map is tuple whose elements are:
* - The CTE definition
* - The number of incoming references to the CTE. This includes references from
* other CTEs and regular places.
* - A mutable inner map that tracks outgoing references (counts) to other CTEs.
* @param outerCTEId While collecting the map we use this optional CTE id to identify the
* current outer CTE.
*/
def buildCTEMap(
plan: LogicalPlan,
cteMap: mutable.HashMap[Long, (CTERelationDef, Int)]): Unit = {
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
outerCTEId: Option[Long] = None): Unit = {
plan match {
case WithCTE(_, cteDefs) =>
case WithCTE(child, cteDefs) =>
cteDefs.foreach { cteDef =>
cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0))
}
cteDefs.foreach { cteDef =>
cteMap.put(cteDef.id, (cteDef, 0))
buildCTEMap(cteDef, cteMap, Some(cteDef.id))
}
buildCTEMap(child, cteMap, outerCTEId)

case ref: CTERelationRef =>
val (cteDef, refCount) = cteMap(ref.cteId)
cteMap.update(ref.cteId, (cteDef, refCount + 1))
val (cteDef, refCount, refMap) = cteMap(ref.cteId)
cteMap(ref.cteId) = (cteDef, refCount + 1, refMap)
outerCTEId.foreach { cteId =>
val (_, _, outerRefMap) = cteMap(cteId)
outerRefMap(ref.cteId) += 1
}

case _ =>
}

if (plan.containsPattern(CTE)) {
plan.children.foreach { child =>
buildCTEMap(child, cteMap)
}
if (plan.containsPattern(CTE)) {
plan.children.foreach { child =>
buildCTEMap(child, cteMap, outerCTEId)
}

plan.expressions.foreach { expr =>
if (expr.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
expr.foreach {
case e: SubqueryExpression =>
buildCTEMap(e.plan, cteMap)
case _ =>
plan.expressions.foreach { expr =>
if (expr.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
expr.foreach {
case e: SubqueryExpression => buildCTEMap(e.plan, cteMap, outerCTEId)
case _ =>
}
}
}
}
}
}

/**
* Cleans the CTE map by removing those CTEs that are not referenced at all and corrects those
* CTE's reference counts where the removed CTE referred to.
*
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
* ids. Needs to be sorted to speed up cleaning.
*/
private def cleanCTEMap(
cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
) = {
cteMap.keys.toSeq.reverse.foreach { currentCTEId =>
val (_, currentRefCount, refMap) = cteMap(currentCTEId)
if (currentRefCount == 0) {
refMap.foreach { case (referencedCTEId, uselessRefCount) =>
val (cteDef, refCount, refMap) = cteMap(referencedCTEId)
cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap)
}
}
}
}

private def inlineCTE(
plan: LogicalPlan,
cteMap: mutable.HashMap[Long, (CTERelationDef, Int)],
cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
plan match {
case WithCTE(child, cteDefs) =>
cteDefs.foreach { cteDef =>
val (cte, refCount) = cteMap(cteDef.id)
val (cte, refCount, refMap) = cteMap(cteDef.id)
if (refCount > 0) {
val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined))
cteMap.update(cteDef.id, (inlined, refCount))
cteMap(cteDef.id) = (inlined, refCount, refMap)
if (!shouldInline(inlined, refCount)) {
notInlined.append(inlined)
}
Expand All @@ -120,7 +162,7 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
inlineCTE(child, cteMap, notInlined)

case ref: CTERelationRef =>
val (cteDef, refCount) = cteMap(ref.cteId)
val (cteDef, refCount, _) = cteMap(ref.cteId)
if (shouldInline(cteDef, refCount)) {
if (ref.outputSet == cteDef.outputSet) {
cteDef.child
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4648,6 +4648,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
sql("SELECT /*+ hash(t2) */ * FROM t1 join t2 on c1 = c2")
}
}

test("SPARK-43199: InlineCTE is idempotent") {
sql(
"""
|WITH
| x(r) AS (SELECT random()),
| y(r) AS (SELECT * FROM x),
| z(r) AS (SELECT * FROM x)
|SELECT * FROM z
|""".stripMargin).collect()
}
}

case class Foo(bar: Option[String])

0 comments on commit 8970415

Please sign in to comment.