[VL] Prepare shim API for breaking change in SPARK-48610 (apache#7445)
zhztheplayer authored Oct 10, 2024
1 parent a6c0798 commit 16bc416
Showing 7 changed files with 160 additions and 83 deletions.
GlutenExplainUtils.scala

@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.gluten.execution.WholeStageTransformer
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.PlanUtil

import org.apache.spark.sql.AnalysisException
@@ -49,7 +50,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
p: SparkPlan,
reason: String,
fallbackNodeToReason: mutable.HashMap[String, String]): Unit = {
p.getTagValue(QueryPlan.OP_ID_TAG).foreach {
SparkShimLoader.getSparkShims.getOperatorId(p).foreach {
opId =>
// e.g., 002 project, it is used to help analysis by `substring(4)`
val formattedNodeName = f"$opId%03d ${p.nodeName}"
@@ -150,94 +151,99 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
// scalastyle:off
/**
* Given an input physical plan, performs the following tasks.
* 1. Generates the explain output for the input plan excluding the subquery plans.
* 2. Generates the explain output for each subquery referenced in the plan.
* 1. Generates the explain output for the input plan excluding the subquery plans. 2. Generates
* the explain output for each subquery referenced in the plan.
*/
// scalastyle:on
// spotless:on
def processPlan[T <: QueryPlan[T]](
plan: T,
append: String => Unit,
collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized {
try {
// Initialize a reference-unique set of Operators to avoid accidental overwrites and to allow
// intentional overwriting of IDs generated in previous AQE iteration
val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap())
// Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
// Exchanges as part of SPARK-42753
val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]
collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo =
synchronized {
SparkShimLoader.getSparkShims.withOperatorIdMap(
new java.util.IdentityHashMap[QueryPlan[_], Int]()) {
try {
// Initialize a reference-unique set of Operators to avoid accidental overwrites and to
// allow intentional overwriting of IDs generated in previous AQE iteration
val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap())
// Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
// Exchanges as part of SPARK-42753
val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]

var currentOperatorID = 0
currentOperatorID =
generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true)
var currentOperatorID = 0
currentOperatorID =
generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true)

val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)]
getSubqueries(plan, subqueries)
val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)]
getSubqueries(plan, subqueries)

currentOperatorID = subqueries.foldLeft(currentOperatorID) {
(curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true)
}
currentOperatorID = subqueries.foldLeft(currentOperatorID) {
(curId, plan) =>
generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true)
}

// SPARK-42753: Process subtree for a ReusedExchange with unknown child
val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
reusedExchanges.foreach {
reused =>
val child = reused.child
if (!operators.contains(child)) {
optimizedOutExchanges.append(child)
currentOperatorID =
generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false)
// SPARK-42753: Process subtree for a ReusedExchange with unknown child
val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
reusedExchanges.foreach {
reused =>
val child = reused.child
if (!operators.contains(child)) {
optimizedOutExchanges.append(child)
currentOperatorID =
generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false)
}
}
}

val collectedOperators = BitSet.empty
processPlanSkippingSubqueries(plan, append, collectedOperators)
val collectedOperators = BitSet.empty
processPlanSkippingSubqueries(plan, append, collectedOperators)

var i = 0
for (sub <- subqueries) {
if (i == 0) {
append("\n===== Subqueries =====\n\n")
}
i = i + 1
append(
s"Subquery:$i Hosting operator id = " +
s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n")
var i = 0
for (sub <- subqueries) {
if (i == 0) {
append("\n===== Subqueries =====\n\n")
}
i = i + 1
append(
s"Subquery:$i Hosting operator id = " +
s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n")

// For each subquery expression in the parent plan, process its child plan to compute
// the explain output. In case of subquery reuse, we don't print subquery plan more
// than once. So we skip [[ReusedSubqueryExec]] here.
if (!sub._3.isInstanceOf[ReusedSubqueryExec]) {
processPlanSkippingSubqueries(sub._3.child, append, collectedOperators)
}
append("\n")
}
// For each subquery expression in the parent plan, process its child plan to compute
// the explain output. In case of subquery reuse, we don't print subquery plan more
// than once. So we skip [[ReusedSubqueryExec]] here.
if (!sub._3.isInstanceOf[ReusedSubqueryExec]) {
processPlanSkippingSubqueries(sub._3.child, append, collectedOperators)
}
append("\n")
}

i = 0
optimizedOutExchanges.foreach {
exchange =>
if (i == 0) {
append("\n===== Adaptively Optimized Out Exchanges =====\n\n")
i = 0
optimizedOutExchanges.foreach {
exchange =>
if (i == 0) {
append("\n===== Adaptively Optimized Out Exchanges =====\n\n")
}
i = i + 1
append(s"Subplan:$i\n")
processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators)
append("\n")
}
i = i + 1
append(s"Subplan:$i\n")
processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators)
append("\n")
}

(subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan)
.map {
plan =>
if (collectFallbackFunc.isEmpty) {
collectFallbackNodes(plan)
} else {
collectFallbackFunc.get.apply(plan)
(subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan)
.map {
plan =>
if (collectFallbackFunc.isEmpty) {
collectFallbackNodes(plan)
} else {
collectFallbackFunc.get.apply(plan)
}
}
.reduce((a, b) => (a._1 + b._1, a._2 ++ b._2))
} finally {
removeTags(plan)
}
.reduce((a, b) => (a._1 + b._1, a._2 ++ b._2))
} finally {
removeTags(plan)
}
}
}
// scalastyle:on
// spotless:on

/**
* Traverses the supplied input plan in a bottom-up fashion and records the operator id via
@@ -288,7 +294,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
}
visited.add(plan)
currentOperationID += 1
plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID)
SparkShimLoader.getSparkShims.setOperatorId(plan, currentOperationID)
}

plan.foreachUp {
@@ -358,12 +364,12 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
* value.
*/
private def getOpId(plan: QueryPlan[_]): String = {
plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown")
SparkShimLoader.getSparkShims.getOperatorId(plan).map(v => s"$v").getOrElse("unknown")
}

private def removeTags(plan: QueryPlan[_]): Unit = {
def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = {
p.unsetTagValue(QueryPlan.OP_ID_TAG)
SparkShimLoader.getSparkShims.unsetOperatorId(p)
children.foreach(removeTags)
}

GlutenImplicits.scala

@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, LogicalPlan}
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf
import org.apache.spark.sql.execution.GlutenExplainUtils._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, QueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
@@ -42,8 +41,8 @@ import scala.collection.mutable.ArrayBuffer
* A helper class to get the Gluten fallback summary from a Spark [[Dataset]].
*
* Note that, if AQE is enabled, but the query is not materialized, then this method will re-plan
* the query execution with disabled AQE. It is a workaround to get the final plan, and it may
* cause the inconsistent results with a materialized query. However, we have no choice.
* the query execution with disabled AQE. It is a workaround to get the final plan, and it may cause
* the inconsistent results with a materialized query. However, we have no choice.
*
* For example:
*
@@ -96,7 +95,9 @@ object GlutenImplicits {
args.substring(index + "isFinalPlan=".length).trim.toBoolean
}

private def collectFallbackNodes(spark: SparkSession, plan: QueryPlan[_]): FallbackInfo = {
private def collectFallbackNodes(
spark: SparkSession,
plan: QueryPlan[_]): GlutenExplainUtils.FallbackInfo = {
var numGlutenNodes = 0
val fallbackNodeToReason = new mutable.HashMap[String, String]

@@ -131,7 +132,7 @@ object GlutenImplicits {
spark,
newSparkPlan
)
processPlan(
GlutenExplainUtils.processPlan(
newExecutedPlan,
new PlanStringConcat().append,
Some(plan => collectFallbackNodes(spark, plan)))
@@ -146,12 +147,15 @@ object GlutenImplicits {
if (PlanUtil.isGlutenTableCache(i)) {
numGlutenNodes += 1
} else {
addFallbackNodeWithReason(i, "Columnar table cache is disabled", fallbackNodeToReason)
GlutenExplainUtils.addFallbackNodeWithReason(
i,
"Columnar table cache is disabled",
fallbackNodeToReason)
}
collect(i.relation.cachedPlan)
case _: AQEShuffleReadExec => // Ignore
case p: SparkPlan =>
handleVanillaSparkPlan(p, fallbackNodeToReason)
GlutenExplainUtils.handleVanillaSparkPlan(p, fallbackNodeToReason)
p.innerChildren.foreach(collect)
case _ =>
}
@@ -181,10 +185,10 @@ object GlutenImplicits {
// AQE is not materialized, so the columnar rules are not applied.
// For this case, We apply columnar rules manually with disable AQE.
val qe = spark.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP)
processPlan(qe.executedPlan, concat.append, collectFallbackFunc)
GlutenExplainUtils.processPlan(qe.executedPlan, concat.append, collectFallbackFunc)
}
} else {
processPlan(plan, concat.append, collectFallbackFunc)
GlutenExplainUtils.processPlan(plan, concat.append, collectFallbackFunc)
}
totalNumGlutenNodes += numGlutenNodes
totalNumFallbackNodes += fallbackNodeToReason.size
SparkShims.scala

@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -270,4 +271,18 @@ trait SparkShims {
def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
throw new UnsupportedOperationException("ArrayInsert not supported.")
}

/** Shim method for usages from GlutenExplainUtils.scala. */
def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = {
body
}

/** Shim method for usages from GlutenExplainUtils.scala. */
def getOperatorId(plan: QueryPlan[_]): Option[Int]

/** Shim method for usages from GlutenExplainUtils.scala. */
def setOperatorId(plan: QueryPlan[_], opId: Int): Unit

/** Shim method for usages from GlutenExplainUtils.scala. */
def unsetOperatorId(plan: QueryPlan[_]): Unit
}
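
The four members added above are the whole surface that GlutenExplainUtils now goes through. Below is a condensed, self-contained sketch of that caller-side pattern; it is illustrative rather than Gluten code. The object and method names in the sketch are made up, while SparkShimLoader, withOperatorIdMap, getOperatorId, setOperatorId and unsetOperatorId are the APIs shown in this diff.

```scala
import java.util.IdentityHashMap

import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.SparkPlan

// Illustrative only: a minimal caller of the new shim surface, mirroring what
// GlutenExplainUtils.processPlan does (explain formatting and subqueries omitted).
object OperatorIdUsageSketch {
  def numberOperators(plan: SparkPlan): Unit = {
    val shims = SparkShimLoader.getSparkShims
    // Scope all operator-ID bookkeeping to a single explain pass.
    shims.withOperatorIdMap(new IdentityHashMap[QueryPlan[_], Int]()) {
      try {
        var currentOperatorID = 0
        plan.foreachUp { p =>
          currentOperatorID += 1
          shims.setOperatorId(p, currentOperatorID) // was p.setTagValue(QueryPlan.OP_ID_TAG, ...)
        }
        // was plan.getTagValue(QueryPlan.OP_ID_TAG)
        val rootId = shims.getOperatorId(plan).map(_.toString).getOrElse("unknown")
        println(s"root operator id = $rootId")
      } finally {
        plan.foreach(shims.unsetOperatorId) // was p.unsetTagValue(QueryPlan.OP_ID_TAG)
      }
    }
  }
}
```

Routing everything through SparkShimLoader means existing Spark versions keep the OP_ID_TAG behaviour, while a later shim can swap in a different ID store without touching GlutenExplainUtils.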
Spark32Shims.scala

@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -283,4 +284,16 @@ class Spark32Shims extends SparkShims {
val s = decimalType.scale
DecimalType(p, if (toScale > s) s else toScale)
}

override def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
plan.getTagValue(QueryPlan.OP_ID_TAG)
}

override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
plan.setTagValue(QueryPlan.OP_ID_TAG, opId)
}

override def unsetOperatorId(plan: QueryPlan[_]): Unit = {
plan.unsetTagValue(QueryPlan.OP_ID_TAG)
}
}
Spark33Shims.scala

@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, RegrR2, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -364,4 +365,16 @@ class Spark33Shims extends SparkShims {
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
)
}

override def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
plan.getTagValue(QueryPlan.OP_ID_TAG)
}

override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
plan.setTagValue(QueryPlan.OP_ID_TAG, opId)
}

override def unsetOperatorId(plan: QueryPlan[_]): Unit = {
plan.unsetTagValue(QueryPlan.OP_ID_TAG)
}
}
Spark34Shims.scala

@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -499,4 +500,16 @@ class Spark34Shims extends SparkShims {
val expr = arrayInsert.asInstanceOf[ArrayInsert]
Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
}

override def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
plan.getTagValue(QueryPlan.OP_ID_TAG)
}

override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
plan.setTagValue(QueryPlan.OP_ID_TAG, opId)
}

override def unsetOperatorId(plan: QueryPlan[_]): Unit = {
plan.unsetTagValue(QueryPlan.OP_ID_TAG)
}
}
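
The Spark 3.2/3.3/3.4 shims above all keep the existing OP_ID_TAG behaviour, and the default withOperatorIdMap in the trait simply runs its body. For a Spark version that includes SPARK-48610, the tag is presumably no longer available, so a shim would have to keep operator IDs in the IdentityHashMap that processPlan now supplies. The sketch below shows one possible shape for that; it is an assumption about SPARK-48610 rather than code from this commit, and the object name and the ThreadLocal bridge are illustrative.

```scala
import java.util.{Map => JMap}

import org.apache.spark.sql.catalyst.plans.QueryPlan

// Hypothetical map-backed operator-ID bookkeeping for a Spark version that ships
// SPARK-48610. Assumes the OP_ID_TAG plan tag is gone, so IDs live in the per-query
// map that GlutenExplainUtils.processPlan passes to withOperatorIdMap. Only the four
// operator-ID members of the shim surface are sketched here.
object MapBackedOperatorIds {
  private val idMapHolder = new ThreadLocal[JMap[QueryPlan[_], Int]]

  def withOperatorIdMap[T](idMap: JMap[QueryPlan[_], Int])(body: => T): T = {
    val previous = idMapHolder.get()
    idMapHolder.set(idMap)
    try body
    finally idMapHolder.set(previous)
  }

  def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
    val m = idMapHolder.get()
    if (m != null && m.containsKey(plan)) Some(m.get(plan)) else None
  }

  def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
    val m = idMapHolder.get()
    if (m != null) m.put(plan, opId)
  }

  def unsetOperatorId(plan: QueryPlan[_]): Unit = {
    val m = idMapHolder.get()
    if (m != null) m.remove(plan)
  }
}
```

A concrete shim class for such a Spark version could delegate its getOperatorId/setOperatorId/unsetOperatorId overrides to a holder like this while leaving the rest of the SparkShims surface unchanged.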
