Skip to content

Commit

Permalink
code refactoring and adding test
Browse files Browse the repository at this point in the history
  • Loading branch information
kanzhang committed Apr 9, 2014
1 parent b073ee6 commit cd0629f
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 19 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -882,14 +882,14 @@ class SparkContext(
metadataCleaner.cancel()
cleaner.foreach(_.stop())
dagSchedulerCopy.stop()
listenerBus.stop()
eventLogger.foreach(_.stop())
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
ResultTask.clearCache()
listenerBus.stop()
eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
} else {
logInfo("SparkContext already stopped")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,19 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY)
private var queueFullErrorMessageLogged = false
private var started = false
private var sparkListenerBus: Option[Thread] = _
private val listenerThread = new Thread("SparkListenerBus") {
setDaemon(true)
override def run() {
while (true) {
val event = eventQueue.take
if (event == SparkListenerShutdown) {
// Get out of the while loop and shutdown the daemon thread
return
}
postToAll(event)
}
}
}

/**
* Start sending events to attached listeners.
Expand All @@ -49,21 +61,8 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
if (started) {
throw new IllegalStateException("Listener bus already started!")
}
listenerThread.start()
started = true
sparkListenerBus = Some(new Thread("SparkListenerBus") {
setDaemon(true)
override def run() {
while (true) {
val event = eventQueue.take
if (event == SparkListenerShutdown) {
// Get out of the while loop and shutdown the daemon thread
return
}
postToAll(event)
}
}
})
sparkListenerBus.foreach(_.start())
}

def post(event: SparkListenerEvent) {
Expand Down Expand Up @@ -99,6 +98,6 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!")
}
post(SparkListenerShutdown)
sparkListenerBus.foreach(_.join())
listenerThread.join()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.scheduler

import java.util.concurrent.Semaphore

import scala.collection.mutable

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
Expand Down Expand Up @@ -72,6 +74,53 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
}

test("bus.stop() waits for the event queue to completely drain") {
@volatile var drained = false

class BlockingListener(cond: AnyRef) extends SparkListener {
override def onJobEnd(jobEnd: SparkListenerJobEnd) = {
cond.synchronized { cond.wait() }
drained = true
}
}

val bus = new LiveListenerBus
val blockingListener = new BlockingListener(bus)
val sem = new Semaphore(0)

bus.addListener(blockingListener)
bus.post(SparkListenerJobEnd(0, JobSucceeded))
bus.start()
// the queue should not drain immediately
assert(!drained)

new Thread("ListenerBusStopper") {
override def run() {
// stop() would block until notify() is called below
bus.stop()
sem.release()
}
}.start()

val startTime = System.currentTimeMillis()
val waitTime = 100
var done = false
while (!done) {
if (System.currentTimeMillis() > startTime + waitTime) {
bus.synchronized {
bus.notify()
}
done = true
} else {
Thread.sleep(10)
// bus.stop() should wait until the event queue is drained
assert(!drained)
}
}
sem.acquire()
assert(drained)
}

test("basic creation of StageInfo") {
val listener = new SaveStageAndTaskInfo
sc.addSparkListener(listener)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,5 @@ object SparkHdfsLR {

println("Final w: " + w)
sc.stop()
System.exit(0)
}
}

0 comments on commit cd0629f

Please sign in to comment.