diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e27b3c4962221..ca2c2add3b449 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -33,6 +33,7 @@ import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} import scala.util.control.NonFatal +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -347,8 +348,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + protected[spark] val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = { + // Note: make a clone such that changes in the parent properties aren't reflected in + // the those of the children threads, which has confusing semantics (SPARK-10563). + SerializationUtils.clone(parent).asInstanceOf[Properties] + } override protected def initialValue(): Properties = new Properties() } diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index a96a4ce201c21..54c131cdae367 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -147,7 +147,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { }.start() } sem.acquire(2) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } if (ThreadingSuiteState.failed.get()) { logError("Waited 1 second without seeing runningThreads = 4 (it was " + ThreadingSuiteState.runningThreads.get() + "); failing test") @@ -178,7 +178,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === null) } @@ -207,58 +207,41 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) } - test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { - val jobStarted = new Semaphore(0) - val jobEnded = new Semaphore(0) - @volatile var jobResult: JobResult = null - var throwable: Option[Throwable] = None - + test("mutation in parent local property does not affect child (SPARK-10563)") { sc = new SparkContext("local", "test") - sc.setJobGroup("originalJobGroupId", "description") - sc.addSparkListener(new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - jobStarted.release() - } - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - jobResult = jobEnd.jobResult - jobEnded.release() - } - }) - - // Create a new thread which will inherit the current thread's properties - val thread = new Thread() { + val originalTestValue: String = "original-value" + var threadTestValue: String = null + sc.setLocalProperty("test", originalTestValue) + var throwable: Option[Throwable] = None + val thread = new Thread { override def run(): Unit = { try { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task - try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) - } - } catch { - case s: SparkException => // ignored so that we don't print noise in test logs - } + threadTestValue = sc.getLocalProperty("test") } catch { case t: Throwable => throwable = Some(t) } } } + sc.setLocalProperty("test", "this-should-not-be-inherited") thread.start() - // Wait for the job to start, then mutate the original properties, which should have been - // inherited by the running job but hopefully defensively copied or snapshotted: - jobStarted.tryAcquire(10, TimeUnit.SECONDS) - sc.setJobGroup("modifiedJobGroupId", "description") - // Canceling the original job group should cancel the running job. In other words, the - // modification of the properties object should not affect the properties of running jobs - sc.cancelJobGroup("originalJobGroupId") - jobEnded.tryAcquire(10, TimeUnit.SECONDS) - throwable.foreach { t => throw t } - assert(jobResult.isInstanceOf[JobFailed]) + thread.join() + throwable.foreach { t => throw improveStackTrace(t) } + assert(threadTestValue === originalTestValue) } + + /** + * Improve the stack trace of an error thrown from within a thread. + * Otherwise it's difficult to tell which line in the test the error came from. + */ + private def improveStackTrace(t: Throwable): Throwable = { + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + t + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala new file mode 100644 index 0000000000000..63639681ef80a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.Properties + +import scala.collection.parallel.CompositeThrowable + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext + +class SQLExecutionSuite extends SparkFunSuite { + + test("concurrent query execution (SPARK-10548)") { + // Try to reproduce the issue with the old SparkContext + val conf = new SparkConf() + .setMaster("local[*]") + .setAppName("test") + val badSparkContext = new BadSparkContext(conf) + try { + testConcurrentQueryExecution(badSparkContext) + fail("unable to reproduce SPARK-10548") + } catch { + case e: IllegalArgumentException => + assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) + } finally { + badSparkContext.stop() + } + + // Verify that the issue is fixed with the latest SparkContext + val goodSparkContext = new SparkContext(conf) + try { + testConcurrentQueryExecution(goodSparkContext) + } finally { + goodSparkContext.stop() + } + } + + /** + * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. + */ + private def testConcurrentQueryExecution(sc: SparkContext): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Initialize local properties. This is necessary for the test to pass. + sc.getLocalProperties + + // Set up a thread that runs executes a simple SQL query. + // Before starting the thread, mutate the execution ID in the parent. + // The child thread should not see the effect of this change. + var throwable: Option[Throwable] = None + val child = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100).map { i => (i, i) }.toDF("a", "b").collect() + } catch { + case t: Throwable => + throwable = Some(t) + } + + } + } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "anything") + child.start() + child.join() + + // The throwable is thrown from the child thread so it doesn't have a helpful stack trace + throwable.foreach { t => + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + throw t + } + } + +} + +/** + * A bad [[SparkContext]] that does not clone the inheritable thread local properties + * when passing them to children threads. + */ +private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { + protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } +}