Skip to content

Commit

Permalink
Merge pull request alteryx#67 from kayousterhout/remove_tsl
Browse files Browse the repository at this point in the history
Removed TaskSchedulerListener interface.

The interface was used only by the DAG scheduler (so it wasn't necessary
to define the additional interface), and the naming makes it very
confusing when reading the code (because "listener" was used
to describe the DAG scheduler, rather than SparkListeners, which
implement a nearly-identical interface but serve a different
function).

@mateiz - is there a reason for this interface that I'm missing?
  • Loading branch information
mateiz committed Oct 17, 2013
2 parents f9973ca + 809f547 commit cf64f63
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,20 @@ class DAGScheduler(
mapOutputTracker: MapOutputTracker,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends TaskSchedulerListener with Logging {
extends Logging {

def this(taskSched: TaskScheduler) {
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
}
taskSched.setListener(this)
taskSched.setDAGScheduler(this)

// Called by TaskScheduler to report task's starting.
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventQueue.put(BeginEvent(task, taskInfo))
}

// Called by TaskScheduler to report task completions or failures.
override def taskEnded(
def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
Expand All @@ -79,18 +79,18 @@ class DAGScheduler(
}

// Called by TaskScheduler when an executor fails.
override def executorLost(execId: String) {
def executorLost(execId: String) {
eventQueue.put(ExecutorLost(execId))
}

// Called by TaskScheduler when a host is added
override def executorGained(execId: String, host: String) {
def executorGained(execId: String, host: String) {
eventQueue.put(ExecutorGained(execId, host))
}

// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
* Each TaskScheduler schedulers task for a single SparkContext.
* These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
* and are responsible for sending the tasks to the cluster, running them, retrying if there
* are failures, and mitigating stragglers. They return events to the DAGScheduler through
* the TaskSchedulerListener interface.
* are failures, and mitigating stragglers. They return events to the DAGScheduler.
*/
private[spark] trait TaskScheduler {

Expand All @@ -48,8 +47,8 @@ private[spark] trait TaskScheduler {
// Cancel a stage.
def cancelTasks(stageId: Int)

// Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
def setListener(listener: TaskSchedulerListener): Unit
// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit

// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
private val executorIdToHost = new HashMap[String, String]

// Listener object to pass upcalls into
var listener: TaskSchedulerListener = null
var dagScheduler: DAGScheduler = null

var backend: SchedulerBackend = null

Expand All @@ -94,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)

override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
override def setDAGScheduler(dagScheduler: DAGScheduler) {
this.dagScheduler = dagScheduler
}

def initialize(context: SchedulerBackend) {
Expand Down Expand Up @@ -297,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor != None) {
listener.executorLost(failedExecutor.get)
dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
Expand Down Expand Up @@ -397,9 +397,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
// Call listener.executorLost without holding the lock on this to prevent deadlock
// Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor != None) {
listener.executorLost(failedExecutor.get)
dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
Expand All @@ -418,7 +418,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}

def executorGained(execId: String, host: String) {
listener.executorGained(execId, host)
dagScheduler.executorGained(execId, host)
}

def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,11 @@ private[spark] class ClusterTaskSetManager(
}

private def taskStarted(task: Task[_], info: TaskInfo) {
sched.listener.taskStarted(task, info)
sched.dagScheduler.taskStarted(task, info)
}

/**
* Marks the task as successful and notifies the listener that a task has ended.
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
Expand All @@ -429,7 +429,7 @@ private[spark] class ClusterTaskSetManager(
if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
tid, info.duration, info.host, tasksSuccessful, numTasks))
sched.listener.taskEnded(
sched.dagScheduler.taskEnded(
tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)

// Mark successful and stop if all the tasks have succeeded.
Expand All @@ -445,7 +445,8 @@ private[spark] class ClusterTaskSetManager(
}

/**
* Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
* Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
* DAG Scheduler.
*/
def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
Expand All @@ -463,7 +464,7 @@ private[spark] class ClusterTaskSetManager(
reason.foreach {
case fetchFailed: FetchFailed =>
logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
successful(index) = true
tasksSuccessful += 1
sched.taskSetFinished(this)
Expand All @@ -472,11 +473,11 @@ private[spark] class ClusterTaskSetManager(

case TaskKilled =>
logWarning("Task %d was killed.".format(tid))
sched.listener.taskEnded(tasks(index), reason.get, null, null, info, null)
sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
return

case ef: ExceptionFailure =>
sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
val key = ef.description
val now = clock.getTime()
val (printFull, dupCount) = {
Expand Down Expand Up @@ -504,7 +505,7 @@ private[spark] class ClusterTaskSetManager(

case TaskResultLost =>
logWarning("Lost result for TID %s on host %s".format(tid, info.host))
sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)

case _ => {}
}
Expand Down Expand Up @@ -533,7 +534,7 @@ private[spark] class ClusterTaskSetManager(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.listener.taskSetFailed(taskSet, message)
sched.dagScheduler.taskSetFailed(taskSet, message)
removeAllRunningTasks()
sched.taskSetFinished(this)
}
Expand Down Expand Up @@ -606,7 +607,7 @@ private[spark] class ClusterTaskSetManager(
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:

val env = SparkEnv.get
val attemptId = new AtomicInteger
var listener: TaskSchedulerListener = null
var dagScheduler: DAGScheduler = null

// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
Expand Down Expand Up @@ -114,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
}

override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
override def setDAGScheduler(dagScheduler: DAGScheduler) {
this.dagScheduler = dagScheduler
}

override def submitTasks(taskSet: TaskSet) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}

def taskStarted(task: Task[_], info: TaskInfo) {
sched.listener.taskStarted(task, info)
sched.dagScheduler.taskStarted(task, info)
}

def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
Expand All @@ -148,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
}
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
Expand All @@ -165,7 +166,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
serializedData, getClass.getClassLoader)
sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
Expand All @@ -176,15 +177,15 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
taskSet.id, index, 4, reason.description)
decreaseRunningTasks(runningTasks)
sched.listener.taskSetFailed(taskSet, errorMessage)
sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
sched.taskSetFinished(this)
}
}
}

override def error(message: String) {
sched.listener.taskSetFailed(taskSet, message)
sched.dagScheduler.taskSetFailed(taskSet, message)
sched.taskSetFinished(this)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSets += taskSet
}
override def cancelTasks(stageId: Int) {}
override def setListener(listener: TaskSchedulerListener) = {}
override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,30 @@ import org.apache.spark.executor.TaskMetrics
import java.nio.ByteBuffer
import org.apache.spark.util.{Utils, FakeClock}

class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
taskScheduler.startedTasks += taskInfo.index
}

override def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: mutable.Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
taskScheduler.endedTasks(taskInfo.index) = reason
}

override def executorGained(execId: String, host: String) {}

override def executorLost(execId: String) {}

override def taskSetFailed(taskSet: TaskSet, reason: String) {
taskScheduler.taskSetsFailed += taskSet.id
}
}

/**
* A mock ClusterScheduler implementation that just remembers information about tasks started and
* feedback received from the TaskSetManagers. Note that it's important to initialize this with
Expand All @@ -44,30 +68,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*

val executors = new mutable.HashMap[String, String] ++ liveExecutors

listener = new TaskSchedulerListener {
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
startedTasks += taskInfo.index
}

def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: mutable.Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics)
{
endedTasks(taskInfo.index) = reason
}

def executorGained(execId: String, host: String) {}

def executorLost(execId: String) {}

def taskSetFailed(taskSet: TaskSet, reason: String) {
taskSetsFailed += taskSet.id
}
}
dagScheduler = new FakeDAGScheduler(this)

def removeExecutor(execId: String): Unit = executors -= execId

Expand Down

0 comments on commit cf64f63

Please sign in to comment.