diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 08cb6c0134e3a..ac44f08897cbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -102,6 +102,13 @@ case class PredicateSubquery(
   override def nullable: Boolean = nullAware
   override def plan: LogicalPlan = SubqueryAlias(toString, query)
   override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
+  override def semanticEquals(o: Expression): Boolean = o match {
+    case p: PredicateSubquery =>
+      query.sameResult(p.query) && nullAware == p.nullAware &&
+        children.length == p.children.length &&
+        children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
+    case _ => false
+  }
   override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 8bce404735785..24a2dc9d3b35f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -538,9 +538,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
     if (innerChildren.nonEmpty) {
       innerChildren.init.foreach(_.generateTreeString(
-        depth + 2, lastChildren :+ false :+ false, builder, verbose))
+        depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose))
       innerChildren.last.generateTreeString(
-        depth + 2, lastChildren :+ false :+ true, builder, verbose)
+        depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose)
     }
 
     if (children.nonEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 5b9af26dfc4f8..d4845637be049 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -101,7 +101,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     PlanSubqueries(sparkSession),
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
-    ReuseExchange(sparkSession.sessionState.conf))
+    ReuseExchange(sparkSession.sessionState.conf),
+    ReuseSubquery(sparkSession.sessionState.conf))
 
   protected def stringOrError[A](f: => A): String =
     try f.toString catch { case e: Throwable => e.toString }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 79cb40948b982..7f2e18586d347 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -142,21 +142,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
    */
   @transient
-  private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]
+  private val runningSubqueries = new ArrayBuffer[ExecSubqueryExpression]
 
   /**
    * Finds scalar subquery expressions in this plan node and starts evaluating them.
-   * The list of subqueries are added to [[subqueryResults]].
    */
   protected def prepareSubqueries(): Unit = {
-    val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
-    allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
-      val futureResult = Future {
-        // Each subquery should return only one row (and one column). We take two here and throws
-        // an exception later if the number of rows is greater than one.
-        e.executedPlan.executeTake(2)
-      }(SparkPlan.subqueryExecutionContext)
-      subqueryResults += e -> futureResult
+    expressions.foreach {
+      _.collect {
+        case e: ExecSubqueryExpression =>
+          e.plan.prepare()
+          runningSubqueries += e
+      }
     }
   }
 
@@ -165,21 +162,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    */
   protected def waitForSubqueries(): Unit = synchronized {
     // fill in the result of subqueries
-    subqueryResults.foreach { case (e, futureResult) =>
-      val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
-      if (rows.length > 1) {
-        sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
-      }
-      if (rows.length == 1) {
-        assert(rows(0).numFields == 1,
-          s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
-        e.updateResult(rows(0).get(0, e.dataType))
-      } else {
-        // If there is no rows returned, the result should be null.
-        e.updateResult(null)
-      }
+    runningSubqueries.foreach { sub =>
+      sub.updateResult()
     }
-    subqueryResults.clear()
+    runningSubqueries.clear()
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index e6f7081f2916d..ad8a71689895b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -17,13 +17,19 @@
 
 package org.apache.spark.sql.execution
 
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
+import org.apache.spark.sql.types.LongType
+import org.apache.spark.util.ThreadUtils
 import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
 /** Physical plan for Project. */
@@ -502,15 +508,64 @@ case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends Spa
 
 /**
  * Physical plan for a subquery.
- *
- * This is used to generate tree string for SparkScalarSubquery.
  */
 case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
+
+  override lazy val metrics = Map(
+    "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
+    "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"))
+
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
+  override def sameResult(o: SparkPlan): Boolean = o match {
+    case s: SubqueryExec => child.sameResult(s.child)
+    case _ => false
+  }
+
+  @transient
+  private lazy val relationFuture: Future[Array[InternalRow]] = {
+    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    Future {
+      // This will run in another thread. Set the execution id so that we can connect these jobs
+      // with the correct execution.
+      SQLExecution.withExecutionId(sparkContext, executionId) {
+        val beforeCollect = System.nanoTime()
+        // Note that we use .executeCollect() because we don't want to convert data to Scala types
+        val rows: Array[InternalRow] = child.executeCollect()
+        val beforeBuild = System.nanoTime()
+        longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+        val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        longMetric("dataSize") += dataSize
+
+        // There are some cases where we don't care about the metrics and call `SparkPlan.doExecute`
+        // directly without setting an execution id. We should be tolerant of that.
+        if (executionId != null) {
+          sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates(
+            executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq))
+        }
+
+        rows
+      }
+    }(SubqueryExec.executionContext)
+  }
+
+  protected override def doPrepare(): Unit = {
+    relationFuture
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
-    throw new UnsupportedOperationException
+    child.execute()
   }
+
+  override def executeCollect(): Array[InternalRow] = {
+    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+  }
+}
+
+object SubqueryExec {
+  private[execution] val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 461d3010ada7e..c730bee6ae050 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,14 +17,38 @@
 
 package org.apache.spark.sql.execution
 
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
+
+/**
+ * The base trait for subquery expressions that are used inside a SparkPlan.
+ */
+trait ExecSubqueryExpression extends SubqueryExpression {
+
+  val executedPlan: SubqueryExec
+  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression
+
+  // does not have a logical plan
+  override def query: LogicalPlan = throw new UnsupportedOperationException
+  override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
+    throw new UnsupportedOperationException
+
+  override def plan: SparkPlan = executedPlan
+
+  /**
+   * Fill the expression with the collected result from the executed plan.
+   */
+  def updateResult(): Unit
+}
 
 /**
  * A subquery that will return only one row and one column.
@@ -32,27 +56,39 @@ import org.apache.spark.sql.types.DataType
  * This is the physical copy of ScalarSubquery to be used inside SparkPlan.
  */
 case class ScalarSubquery(
-    executedPlan: SparkPlan,
+    executedPlan: SubqueryExec,
     exprId: ExprId)
-  extends SubqueryExpression {
-
-  override def query: LogicalPlan = throw new UnsupportedOperationException
-  override def withNewPlan(plan: LogicalPlan): SubqueryExpression = {
-    throw new UnsupportedOperationException
-  }
-  override def plan: SparkPlan = SubqueryExec(simpleString, executedPlan)
+  extends ExecSubqueryExpression {
 
   override def dataType: DataType = executedPlan.schema.fields.head.dataType
   override def children: Seq[Expression] = Nil
   override def nullable: Boolean = true
-  override def toString: String = s"subquery#${exprId.id}"
+  override def toString: String = executedPlan.simpleString
+
+  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case s: ScalarSubquery => executedPlan.sameResult(s.executedPlan)
+    case _ => false
+  }
 
   // the first column in first row from `query`.
   @volatile private var result: Any = null
   @volatile private var updated: Boolean = false
 
-  def updateResult(v: Any): Unit = {
-    result = v
+  def updateResult(): Unit = {
+    val rows = plan.executeCollect()
+    if (rows.length > 1) {
+      sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
+    }
+    if (rows.length == 1) {
+      assert(rows(0).numFields == 1,
+        s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
+      result = rows(0).get(0, dataType)
+    } else {
+      // If no rows are returned, the result should be null.
+      result = null
+    }
     updated = true
   }
 
@@ -67,6 +103,51 @@ case class ScalarSubquery(
   }
 }
 
+/**
+ * A subquery that checks whether the value of `child` is in the result of a query.
+ */
+case class InSubquery(
+    child: Expression,
+    executedPlan: SubqueryExec,
+    exprId: ExprId,
+    private var result: Array[Any] = null,
+    private var updated: Boolean = false) extends ExecSubqueryExpression {
+
+  override def dataType: DataType = BooleanType
+  override def children: Seq[Expression] = child :: Nil
+  override def nullable: Boolean = child.nullable
+  override def toString: String = s"$child IN ${executedPlan.name}"
+
+  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case in: InSubquery => child.semanticEquals(in.child) &&
+      executedPlan.sameResult(in.executedPlan)
+    case _ => false
+  }
+
+  def updateResult(): Unit = {
+    val rows = plan.executeCollect()
+    result = rows.map(_.get(0, child.dataType)).asInstanceOf[Array[Any]]
+    updated = true
+  }
+
+  override def eval(input: InternalRow): Any = {
+    require(updated, s"$this has not finished")
+    val v = child.eval(input)
+    if (v == null) {
+      null
+    } else {
+      result.contains(v)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    require(updated, s"$this has not finished")
+    InSet(child, result.toSet).doGenCode(ctx, ev)
+  }
+}
+
 /**
  * Plans scalar subqueries from that are present in the given [[SparkPlan]].
 */
@@ -75,7 +156,39 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
     plan.transformAllExpressions {
       case subquery: expressions.ScalarSubquery =>
         val executedPlan = new QueryExecution(sparkSession, subquery.plan).executedPlan
-        ScalarSubquery(executedPlan, subquery.exprId)
+        ScalarSubquery(
+          SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
+          subquery.exprId)
+      case expressions.PredicateSubquery(plan, Seq(e: Expression), _, exprId) =>
+        val executedPlan = new QueryExecution(sparkSession, plan).executedPlan
+        InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
     }
   }
 }
+
+
+/**
+ * Find out duplicated subqueries in the Spark plan, then use the same SubqueryExec for all the
+ * references.
+ */
+case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {
+
+  def apply(plan: SparkPlan): SparkPlan = {
+    if (!conf.exchangeReuseEnabled) {
+      return plan
+    }
+    // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls.
+    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
+    plan transformAllExpressions {
+      case sub: ExecSubqueryExpression =>
+        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
+        val sameResult = sameSchema.find(_.sameResult(sub.plan))
+        if (sameResult.isDefined) {
+          sub.withExecutedPlan(sameResult.get)
+        } else {
+          sameSchema += sub.executedPlan
+          sub
+        }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 4bb9d6fef4c1d..9d4ebcce4d103 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -99,7 +99,11 @@ object SparkPlanGraph {
       case "Subquery" if subgraph != null =>
         // Subquery should not be included in WholeStageCodegen
        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
-      case "ReusedExchange" =>
+      case "Subquery" if exchanges.contains(planInfo) =>
+        // Point to the re-used subquery
+        val node = exchanges(planInfo)
+        edges += SparkPlanGraphEdge(node.id, parent.id)
+      case "ReusedExchange" if exchanges.contains(planInfo.children.head) =>
         // Point to the re-used exchange
         val node = exchanges(planInfo.children.head)
         edges += SparkPlanGraphEdge(node.id, parent.id)
@@ -115,7 +119,7 @@ object SparkPlanGraph {
         } else {
           subgraph.nodes += node
         }
-        if (name.contains("Exchange")) {
+        if (name.contains("Exchange") || name == "Subquery") {
          exchanges += planInfo -> node
         }
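Note (illustration only, not part of the patch): the sketch below shows the kind of query the new ReuseSubquery rule targets. The session setup, the view name `t`, and the object name SubqueryReuseSketch are made up for this example; the expectation, under the changes above, is that the two identical uncorrelated scalar subqueries are planned into ScalarSubquery expressions whose SubqueryExec children compare equal via sameResult, so both expressions can share a single SubqueryExec and the subquery runs only once.

// Hypothetical usage sketch, assuming a local Spark 2.x-style session.
import org.apache.spark.sql.SparkSession

object SubqueryReuseSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("subquery-reuse-sketch")
      .getOrCreate()

    spark.range(100).createOrReplaceTempView("t")

    // The same uncorrelated scalar subquery appears twice; after PlanSubqueries wraps each
    // occurrence in a SubqueryExec, ReuseSubquery should be able to point both ScalarSubquery
    // expressions at a single SubqueryExec instance.
    val df = spark.sql(
      """
        |SELECT id
        |FROM t
        |WHERE id > (SELECT avg(id) FROM t)
        |   OR id = (SELECT avg(id) FROM t)
      """.stripMargin)

    df.explain(true)  // inspect the physical plan for a shared Subquery node
    df.show()

    spark.stop()
  }
}

Whether the two subqueries actually collapse depends on the sameResult/semanticEquals overrides added above: ReuseSubquery only rewrites expressions whose SubqueryExec plans are found equal in its schema-keyed map.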