diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index ec07aab359ac6..be35916e3447e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext @@ -32,6 +33,12 @@ object SQLExecution { private def nextExecutionId: Long = _nextExecutionId.getAndIncrement + private val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]() + + def getQueryExecution(executionId: Long): QueryExecution = { + executionIdToQueryExecution.get(executionId) + } + /** * Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that * we can connect them with an execution. @@ -44,6 +51,7 @@ object SQLExecution { if (oldExecutionId == null) { val executionId = SQLExecution.nextExecutionId sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) + executionIdToQueryExecution.put(executionId, queryExecution) val r = try { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on @@ -60,6 +68,7 @@ object SQLExecution { executionId, System.currentTimeMillis())) } } finally { + executionIdToQueryExecution.remove(executionId) sc.setLocalProperty(EXECUTION_ID_KEY, null) } r diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index ad41111bec9d6..b0597067839dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import java.util.Properties import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.SparkSession class SQLExecutionSuite extends SparkFunSuite { @@ -102,6 +103,33 @@ class SQLExecutionSuite extends SparkFunSuite { } } + + test("Finding QueryExecution for given executionId") { + val spark = SparkSession.builder.master("local[*]").appName("test").getOrCreate() + import spark.implicits._ + + var queryExecution: QueryExecution = null + + spark.sparkContext.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + val executionIdStr = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) + if (executionIdStr != null) { + queryExecution = SQLExecution.getQueryExecution(executionIdStr.toLong) + } + SQLExecutionSuite.canProgress = true + } + }) + + val df = spark.range(1).map { x => + while (!SQLExecutionSuite.canProgress) { + Thread.sleep(1) + } + x + } + df.collect() + + assert(df.queryExecution === queryExecution) + } } /** @@ -114,3 +142,7 @@ private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { override protected def initialValue(): Properties = new Properties() } } + +object SQLExecutionSuite { + @volatile var canProgress = false +}