Skip to content

Commit

Permalink
Simplify fix by cloning properties on inherit
Browse files Browse the repository at this point in the history
The fix for SPARK-10548 can be simplified by just cloning the
parent properties on inherit rather than excluding specific
properties from ever being inherited. This is safe because the
child thread must be created BEFORE the parent thread runs a
query.
  • Loading branch information
Andrew Or committed Sep 14, 2015
1 parent 35bb6f0 commit 984a92f
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 103 deletions.
22 changes: 4 additions & 18 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ import java.util.UUID.randomUUID
import scala.collection.JavaConverters._
import scala.collection.{Map, Set}
import scala.collection.generic.Growable
import scala.collection.mutable.{HashMap, HashSet}
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
import scala.util.control.NonFatal

import org.apache.commons.lang.SerializationUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
Expand Down Expand Up @@ -347,30 +348,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
private[spark] var checkpointDir: Option[String] = None

// Thread Local variable that can be used by users to pass information down the stack
private val localProperties = new InheritableThreadLocal[Properties] {
protected[spark] val localProperties = new InheritableThreadLocal[Properties] {
override protected def childValue(parent: Properties): Properties = {
// Note: make a clone such that changes in the parent properties aren't reflected in
// those of the children threads, which has confusing semantics (SPARK-10564).
val p = new Properties
val filtered = parent.asScala.filter { case (k, _) =>
!nonInheritedLocalProperties.contains(k)
}
p.putAll(filtered.asJava)
p
SerializationUtils.clone(parent).asInstanceOf[Properties]
}
override protected def initialValue(): Properties = new Properties()
}

// Keys of local properties that should not be inherited by children threads
private val nonInheritedLocalProperties: HashSet[String] = new HashSet[String]

/**
* Mark a local property such that its values are never inherited across the thread hierarchy.
*/
private[spark] def markLocalPropertyNonInherited(key: String): Unit = {
  // Record the key in the exclusion set. `add` reports whether the key was newly
  // inserted, which callers of this method have no use for, so the result is dropped.
  nonInheritedLocalProperties.add(key)
  ()
}

/* ------------------------------------------------------------------------------------- *
| Initialization. This code initializes the context in a manner that is exception-safe. |
| All internal fields holding state are initialized here, and any error prompts the |
Expand Down
84 changes: 25 additions & 59 deletions core/src/test/scala/org/apache/spark/ThreadingSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
ThreadingSuiteState.runningThreads.get() + "); failing test")
fail("One or more threads didn't see runningThreads = 4")
}
throwable.foreach { t => throw t }
throwable.foreach { t => throw improveStackTrace(t) }
}

test("set local properties in different thread") {
Expand All @@ -179,7 +179,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {

sem.acquire(5)
assert(sc.getLocalProperty("test") === null)
throwable.foreach { t => throw t }
throwable.foreach { t => throw improveStackTrace(t) }
}

test("set and get local properties in parent-children thread") {
Expand Down Expand Up @@ -209,73 +209,39 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
sem.acquire(5)
assert(sc.getLocalProperty("test") === "parent")
assert(sc.getLocalProperty("Foo") === null)
throwable.foreach { t => throw t }
throwable.foreach { t => throw improveStackTrace(t) }
}

test("inheritance exclusions (SPARK-10548)") {
test("mutation in parent local property does not affect child (SPARK-10563)") {
sc = new SparkContext("local", "test")
sc.markLocalPropertyNonInherited("do-not-inherit-me")
sc.setLocalProperty("do-inherit-me", "parent")
sc.setLocalProperty("do-not-inherit-me", "parent")
val originalTestValue: String = "original-value"
var threadTestValue: String = null
sc.setLocalProperty("test", originalTestValue)
var throwable: Option[Throwable] = None
val threads = (1 to 5).map { i =>
new Thread() {
override def run() {
// only the ones we intend to inherit will be passed to the children
try {
assert(sc.getLocalProperty("do-inherit-me") === "parent")
assert(sc.getLocalProperty("do-not-inherit-me") === null)
} catch {
case t: Throwable => throwable = Some(t)
}
}
}
}
threads.foreach(_.start())
threads.foreach(_.join())
throwable.foreach { t => throw t }
}

test("mutations to local properties should not affect submitted jobs (SPARK-6629)") {
val jobStarted = new Semaphore(0)
val jobEnded = new Semaphore(0)
@volatile var jobResult: JobResult = null

sc = new SparkContext("local", "test")
sc.setJobGroup("originalJobGroupId", "description")
sc.addSparkListener(new SparkListener {
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
jobStarted.release()
}
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
jobResult = jobEnd.jobResult
jobEnded.release()
}
})

// Create a new thread which will inherit the current thread's properties
val thread = new Thread() {
val thread = new Thread {
override def run(): Unit = {
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId")
// Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task
try {
sc.parallelize(1 to 100).foreach { x =>
Thread.sleep(100)
}
threadTestValue = sc.getLocalProperty("test")
} catch {
case s: SparkException => // ignored so that we don't print noise in test logs
case t: Throwable =>
throwable = Some(t)
}
}
}
sc.setLocalProperty("test", "this-should-not-be-inherited")
thread.start()
// Wait for the job to start, then mutate the original properties, which should have been
// inherited by the running job but hopefully defensively copied or snapshotted:
jobStarted.tryAcquire(10, TimeUnit.SECONDS)
sc.setJobGroup("modifiedJobGroupId", "description")
// Canceling the original job group should cancel the running job. In other words, the
// modification of the properties object should not affect the properties of running jobs
sc.cancelJobGroup("originalJobGroupId")
jobEnded.tryAcquire(10, TimeUnit.SECONDS)
assert(jobResult.isInstanceOf[JobFailed])
thread.join()
throwable.foreach { t => throw improveStackTrace(t) }
assert(threadTestValue === originalTestValue)
}

/**
* Improve the stack trace of an error thrown from within a thread.
* Otherwise it's difficult to tell which line in the test the error came from.
*/
private def improveStackTrace(t: Throwable): Throwable = {
  // Frames recorded where the error was originally raised.
  val raisedAt = t.getStackTrace
  // Frames of the thread calling this helper, i.e. the code that is about to rethrow.
  val rethrownFrom = Thread.currentThread.getStackTrace
  // Original frames stay first; the caller's frames are appended after them.
  t.setStackTrace(Array.concat(raisedAt, rethrownFrom))
  // The same instance is handed back (mutated in place) so it can be used in `throw`.
  t
}

}
5 changes: 0 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
sparkContext.addSparkListener(listener)
sparkContext.ui.foreach(new SQLTab(this, _))

// Ensure query execution IDs are not inherited across the thread hierarchy, which is
// the default behavior for SparkContext local properties. Otherwise, we may confuse
// the listener as to which query is being executed. (SPARK-10548)
sparkContext.markLocalPropertyNonInherited(SQLExecution.EXECUTION_ID_KEY)

/**
* Set Spark SQL configuration properties.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,85 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.test.SharedSQLContext
import java.util.Properties

class SQLExecutionSuite extends SharedSQLContext {
import testImplicits._
import scala.collection.parallel.CompositeThrowable

test("query execution IDs are not inherited across threads") {
sparkContext.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "123")
sparkContext.setLocalProperty("do-inherit-me", "some-value")
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.SQLContext

class SQLExecutionSuite extends SparkFunSuite {

test("concurrent query execution (SPARK-10548)") {
  val conf = new SparkConf().setMaster("local[*]").setAppName("test")

  // Phase 1: with a context that does not clone inherited properties, the concurrent
  // query is expected to fail with an error that mentions the execution ID key.
  val brokenContext = new BadSparkContext(conf)
  try {
    testConcurrentQueryExecution(brokenContext)
    fail("unable to reproduce SPARK-10548")
  } catch {
    case expected: IllegalArgumentException =>
      assert(expected.getMessage.contains(SQLExecution.EXECUTION_ID_KEY))
  } finally {
    // Always stop the first context before the second one is created below.
    brokenContext.stop()
  }

  // Phase 2: the regular SparkContext must run the same scenario without throwing.
  val fixedContext = new SparkContext(conf)
  try {
    testConcurrentQueryExecution(fixedContext)
  } finally {
    fixedContext.stop()
  }
}

/**
* Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently.
*/
private def testConcurrentQueryExecution(sc: SparkContext): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Initialize local properties. This is necessary for the test to pass.
sc.getLocalProperties

// Set up a thread that executes a simple SQL query.
// Before starting the thread, mutate the execution ID in the parent.
// The child thread should not see the effect of this change.
var throwable: Option[Throwable] = None
val thread = new Thread {
val child = new Thread {
override def run(): Unit = {
try {
assert(sparkContext.getLocalProperty("do-inherit-me") === "some-value")
assert(sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) === null)
sc.parallelize(1 to 100).map { i => (i, i) }.toDF("a", "b").collect()
} catch {
case t: Throwable =>
throwable = Some(t)
}

}
}
thread.start()
thread.join()
throwable.foreach { t => throw t }
}
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "anything")
child.start()
child.join()

// This is the end-to-end version of the previous test.
test("parallel query execution (SPARK-10548)") {
(1 to 5).foreach { i =>
// Scala's parallel collections spawn new threads as children of the existing threads.
// We need to run this multiple times to ensure new threads are spawned. Without the fix
// for SPARK-10548, this usually fails on the second try.
val df = sparkContext.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b")
(1 to 10).par.foreach { _ => df.count() }
// The throwable is thrown from the child thread so it doesn't have a helpful stack trace
throwable.foreach { t =>
t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace)
throw t
}
}

}

/**
* A bad [[SparkContext]] that does not clone the inheritable thread local properties
* when passing them to children threads.
*/
private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) {
  protected[spark] override val localProperties = new InheritableThreadLocal[Properties] {
    // A thread with no inherited value starts from an empty property set.
    override protected def initialValue(): Properties = {
      new Properties()
    }

    // Deliberately NOT a copy: the child's Properties is created with the parent's
    // object as its defaults table, so the parent instance itself stays referenced
    // by the child rather than being snapshotted.
    override protected def childValue(parent: Properties): Properties = {
      val backedByParent = new Properties(parent)
      backedByParent
    }
  }
}

0 comments on commit 984a92f

Please sign in to comment.