Skip to content
This repository has been archived by the owner on Oct 23, 2024. It is now read-only.

Commit

Permalink
fixes #902 - optimized task queueing behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
drexin committed Jan 12, 2015
1 parent a714f5d commit 646e0fd
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 111 deletions.
42 changes: 12 additions & 30 deletions src/main/scala/mesosphere/marathon/MarathonScheduler.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package mesosphere.marathon

import java.util
import javax.inject.{ Inject, Named }

import akka.actor.{ ActorRef, ActorSystem }
Expand Down Expand Up @@ -97,52 +98,34 @@ class MarathonScheduler @Inject() (
Await.result(appRepo.currentAppVersions(), config.zkTimeoutDuration)

taskQueue.retain {
case QueuedTask(app, _) =>
case QueuedTask(app, _, _) =>
appVersions.get(app.id) contains app.version
}

for (offer <- offers.asScala) {
try {
log.debug("Received offer %s".format(offer))

val queuedTasks: Seq[QueuedTask] = taskQueue.removeAll()

val withTaskInfos: collection.Seq[(QueuedTask, (TaskInfo, Seq[Long]))] =
queuedTasks.view.flatMap { case qt => newTask(qt.app, offer).map(qt -> _) }

val launchedTask = withTaskInfos.find {
case (qt, (taskInfo, ports)) =>
val timeLeft = qt.delay.timeLeft
if (timeLeft.toNanos <= 0) {
true
}
else {
log.info(s"Delaying task ${taskInfo.getTaskId.getValue} due to backoff. Time left: $timeLeft.")
false
}
val matchingTask = taskQueue.pollMatching { app =>
newTask(app, offer).map(app -> _)
}

launchedTask.foreach {
case (qt, (taskInfo, ports)) =>
val taskInfos = Seq(taskInfo)
log.debug("Launching tasks: " + taskInfos)

matchingTask.foreach {
case (app, (taskInfo, ports)) =>
val marathonTask = MarathonTasks.makeTask(
taskInfo.getTaskId.getValue, offer.getHostname, ports,
offer.getAttributesList.asScala, qt.app.version)
offer.getAttributesList.asScala, app.version)

log.debug("Launching task: " + taskInfo)

taskTracker.created(qt.app.id, marathonTask)
driver.launchTasks(Seq(offer.getId).asJava, taskInfos.asJava)
taskTracker.created(app.id, marathonTask)
driver.launchTasks(Seq(offer.getId).asJava, util.Arrays.asList(taskInfo))

// here it is assumed that the health checks for the current
// version are already running.
}

// put unscheduled tasks back in the queue
val launchedTaskSeq: Seq[QueuedTask] = launchedTask.map(_._1).to[Seq]
taskQueue.addAll(queuedTasks diff launchedTaskSeq)

if (launchedTask.isEmpty) {
if (matchingTask.isEmpty) {
log.debug("Offer doesn't match request. Declining.")
driver.declineOffer(offer.getId)
}
Expand Down Expand Up @@ -289,7 +272,6 @@ class MarathonScheduler @Inject() (
/**
 * Attempts to build a task for `app` out of the given resource offer.
 *
 * A fresh `TaskBuilder` is created per call and asked to `buildIfMatches(offer)`.
 *
 * @param app   the app definition a task should be built for
 * @param offer the resource offer to build the task from
 * @return whatever `buildIfMatches` yields: the task info paired with a sequence of
 *         longs (the caller treats them as ports) — presumably `None` when the offer
 *         does not satisfy the app; confirm against `TaskBuilder`
 */
private def newTask(
  app: AppDefinition,
  offer: Offer): Option[(TaskInfo, Seq[Long])] = {
  val builder = new TaskBuilder(app, taskIdUtil.newTaskId, taskTracker, config, mapper)
  builder.buildIfMatches(offer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,7 @@ class SchedulerActions(

if (toQueue > 0) {
log.info(s"Queueing $toQueue new tasks for ${app.id} ($queuedCount queued)")
for (i <- 0 until toQueue)
taskQueue.add(app)
taskQueue.add(app, toQueue)
}
else {
log.info(s"Already queued $queuedCount tasks for ${app.id}. Not scaling.")
Expand Down
88 changes: 61 additions & 27 deletions src/main/scala/mesosphere/marathon/tasks/TaskQueue.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package mesosphere.marathon.tasks

import java.util.concurrent.PriorityBlockingQueue
import java.util.concurrent.atomic.AtomicInteger

import mesosphere.marathon.state.{ AppDefinition, PathId }
import mesosphere.marathon.state.{ Timestamp, AppDefinition, PathId }
import mesosphere.util.RateLimiter
import org.apache.log4j.Logger

import scala.annotation.tailrec
import scala.collection.concurrent.TrieMap
import scala.concurrent.duration.Deadline
import scala.collection.mutable
import scala.collection.immutable.Seq
import scala.collection.JavaConverters._

/**
* Utility class to stage tasks before they get scheduled
Expand All @@ -17,55 +18,88 @@ class TaskQueue {

import mesosphere.marathon.tasks.TaskQueue._

private val log = Logger.getLogger(getClass)
protected[marathon] val rateLimiter = new RateLimiter

// we used SynchronizedPriorityQueue before, but it has been deprecated
// because it is not safe to use
protected[tasks] var queue =
new PriorityBlockingQueue[QueuedTask](11, AppConstraintsOrdering.reverse)
protected[tasks] var apps = TrieMap.empty[(PathId, Timestamp), QueuedTask]

def list: Seq[QueuedTask] = queue.asScala.to[scala.collection.immutable.Seq]
/** All current queue entries, copied into an immutable `Seq` (one entry per app id and version). */
def list: Seq[QueuedTask] = apps.values.to[Seq]

/** The app definitions of all current queue entries, in the order returned by `list`. */
def listApps: Seq[AppDefinition] =
  for (queued <- list) yield queued.app

def poll(): Option[QueuedTask] = Option(queue.poll())

def add(app: AppDefinition): Unit =
queue.add(QueuedTask(app, rateLimiter.getDelay(app)))
/**
 * Takes one task from the queue.
 *
 * Entries are tried in order of descending number of app constraints; the first entry
 * that still has a positive counter gets its counter decremented by one and is returned.
 *
 * The counter is only decremented while it is positive. Decrementing unconditionally
 * would push exhausted entries below zero, and a subsequent `add` would then appear
 * to queue fewer tasks than requested.
 *
 * @return the entry a task was taken from, or `None` if no entry has tasks left
 */
def poll(): Option[QueuedTask] = {
  // compare-and-set loop: take one unit only if the counter is still positive
  @tailrec
  def tryTake(counter: AtomicInteger): Boolean = {
    val current = counter.get()
    if (current <= 0) false
    else if (counter.compareAndSet(current, current - 1)) true
    else tryTake(counter)
  }

  apps.values.toSeq
    .sortWith { (a, b) => a.app.constraints.size > b.app.constraints.size }
    .find(queued => tryTake(queued.count))
}

/** Queues a single task for `app`; shorthand for `add(app, 1)`. */
def add(app: AppDefinition): Unit = add(app, 1)

/**
 * Queues `count` additional tasks for `app`.
 *
 * Entries are keyed by (app id, app version). When no entry exists for the key, one is
 * created with a zero counter and a delay obtained from the rate limiter; afterwards
 * the entry's counter is increased by `count`.
 *
 * @param app   the app to queue tasks for
 * @param count the number of tasks to add to the entry's counter
 */
def add(app: AppDefinition, count: Int): Unit = {
  val key = (app.id, app.version)
  // the default entry is a by-name argument, so the rate limiter is only asked for new keys
  val entry = apps.getOrElseUpdate(
    key,
    QueuedTask(app, new AtomicInteger(0), rateLimiter.getDelay(app)))
  entry.count.addAndGet(count)
}

/**
* Number of tasks in the queue for the given app
*
* @param app The app
* @return count
*/
def count(app: AppDefinition): Int = queue.asScala.count(_.app.id == app.id)
/**
 * Number of tasks queued for the given app, looked up by the app's id and version.
 *
 * The underlying counter is decremented before it is checked when polling, so its raw
 * value can be negative; the result is clamped so callers that sum it with other
 * counts never see a negative number.
 *
 * @param app The app
 * @return the queued count (never negative), or 0 if there is no entry for this id and version
 */
def count(app: AppDefinition): Int =
  apps.get((app.id, app.version)).fold(0)(queued => math.max(0, queued.count.get()))

def purge(appId: PathId): Unit = {
val retained = queue.asScala.filterNot(_.app.id == appId)
removeAll()
queue.addAll(retained.asJavaCollection)
for {
QueuedTask(app, _, _) <- apps.values
if app.id == appId
} apps.remove(app.id -> app.version)
}

/**
* Retains only elements that satisfy the supplied predicate.
*/
def retain(f: (QueuedTask => Boolean)): Unit =
queue.iterator.asScala.foreach { qt => if (!f(qt)) queue.remove(qt) }

def addAll(xs: Seq[QueuedTask]): Unit = queue.addAll(xs.asJavaCollection)

def removeAll(): Seq[QueuedTask] = {
val builder = new java.util.ArrayList[QueuedTask]()
queue.drainTo(builder)
builder.asScala.to[Seq]
apps.values.foreach {
case qt @ QueuedTask(app, _, _) => if (!f(qt)) apps.remove(app.id -> app.version)
}

/**
 * Takes one task from the first queued app for which `f` yields a result.
 *
 * Entries are tried in order of descending number of app constraints. Entries whose
 * delay still has time left are skipped (and logged). For the remaining entries `f`
 * is applied to the app:
 *  - if `f` yields nothing, the entry stays queued and the next entry is tried;
 *  - if `f` yields a result and the entry's counter is not exhausted, the counter is
 *    decremented and the result is returned;
 *  - if `f` yields a result but the counter is exhausted, the entry is removed from
 *    the queue and the next entry is tried.
 *
 * @param f function applied to each candidate app; a defined result counts as a match
 * @tparam B type of the value produced for a match
 * @return the result of `f` for the first matching entry that still had a task queued,
 *         or `None` if there is no such entry
 */
def pollMatching[B](f: AppDefinition => Option[B]): Option[B] = {
  val sorted = apps.values.toList.sortWith { (a, b) =>
    a.app.constraints.size > b.app.constraints.size
  }

  @tailrec
  def findMatching(xs: List[QueuedTask]): Option[B] = xs match {
    case Nil => None

    case QueuedTask(app, _, delay) :: tail if delay.hasTimeLeft() =>
      log.info(s"Delaying ${app.id} due to backoff. Time left: ${delay.timeLeft}.")
      findMatching(tail)

    case QueuedTask(app, count, _) :: tail =>
      f(app) match {
        case None =>
          // no match for this app: it must stay queued for later calls
          findMatching(tail)

        case res @ Some(_) =>
          if (count.decrementAndGet() >= 0) {
            res
          }
          else {
            // app count is 0, so we can remove this app from the queue
            apps.remove(app.id -> app.version)
            findMatching(tail)
          }
      }
  }

  findMatching(sorted)
}

}

object TaskQueue {

protected[marathon] case class QueuedTask(app: AppDefinition, delay: Deadline)
/**
 * A queue entry of the TaskQueue.
 *
 * @param app   the app definition tasks are queued for
 * @param count mutable counter of queued tasks; TaskQueue increases it in `add` and
 *              decrements it when polling
 * @param delay deadline consulted by `pollMatching`, which skips entries that have time left
 */
protected[marathon] case class QueuedTask(app: AppDefinition, count: AtomicInteger, delay: Deadline)

protected object AppConstraintsOrdering extends Ordering[QueuedTask] {
def compare(t1: QueuedTask, t2: QueuedTask): Int =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ trait StartingBehavior { this: Actor with ActorLogging =>
val actualSize = taskQueue.count(app) + taskTracker.count(app.id)

if (actualSize < expectedSize) {
for (_ <- 0 until (expectedSize - actualSize)) taskQueue.add(app)
taskQueue.add(app, expectedSize - actualSize)
}
context.system.scheduler.scheduleOnce(5.seconds, self, Sync)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class TaskReplaceActor(
driver.killTask(taskId)
}

for (_ <- 0 until app.instances) taskQueue.add(app)
taskQueue.add(app, app.instances)
}

override def postStop(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class TaskStartActor(
var running: Int = 0

override def initializeStart(): Unit = {
for (_ <- 0 until nrToStart) taskQueue.add(app)
taskQueue.add(app, nrToStart)
}

override def postStop(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class MarathonSchedulerActorTest extends TestKit(ActorSystem("System"))

awaitAssert({
verify(tracker).shutdown("nope".toPath)
verify(queue).add(app)
verify(queue).add(app, 1)
verify(driver).killTask(TaskID("task_a"))
}, 5.seconds, 10.millis)
}
Expand All @@ -168,7 +168,7 @@ class MarathonSchedulerActorTest extends TestKit(ActorSystem("System"))
schedulerActor ! ScaleApp("test-app".toPath)

awaitAssert({
verify(queue).add(app)
verify(queue).add(app, 1)
}, 5.seconds, 10.millis)

expectMsg(5.seconds, AppScaled(app.id))
Expand Down
13 changes: 6 additions & 7 deletions src/test/scala/mesosphere/marathon/MarathonSchedulerTest.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package mesosphere.marathon

import java.util.concurrent.atomic.AtomicInteger

import akka.actor.ActorSystem
import akka.event.EventStream
import akka.testkit.{ TestKit, TestProbe }
Expand Down Expand Up @@ -45,7 +47,7 @@ class MarathonSchedulerTest extends TestKit(ActorSystem("System")) with Marathon
repo = mock[AppRepository]
hcManager = mock[HealthCheckManager]
tracker = mock[TaskTracker]
queue = mock[TaskQueue]
queue = spy(new TaskQueue)
frameworkIdUtil = mock[FrameworkIdUtil]
config = defaultConfig()
taskIdUtil = mock[TaskIdUtil]
Expand Down Expand Up @@ -81,17 +83,15 @@ class MarathonSchedulerTest extends TestKit(ActorSystem("System")) with Marathon
ports = Seq(8080),
version = now
)
val queuedTask = QueuedTask(app, Deadline.now)
val queuedTask = QueuedTask(app, new AtomicInteger(app.instances), Deadline.now)
val list = Vector(queuedTask)
val allApps = Vector(app)

queue.add(app)

when(taskIdUtil.newTaskId("testOffers".toRootPath))
.thenReturn(TaskID.newBuilder.setValue("testOffers_0-1234").build)
when(tracker.checkStagedTasks).thenReturn(Seq())
when(queue.poll()).thenReturn(Some(queuedTask))
when(queue.list).thenReturn(list)
when(queue.removeAll()).thenReturn(list)
when(queue.listApps).thenReturn(allApps)
when(repo.currentAppVersions())
.thenReturn(Future.successful(Map(app.id -> app.version)))

Expand All @@ -103,7 +103,6 @@ class MarathonSchedulerTest extends TestKit(ActorSystem("System")) with Marathon

verify(driver).launchTasks(offersCaptor.capture(), taskInfosCaptor.capture())
verify(tracker).created(same(app.id), marathonTaskCaptor.capture())
verify(queue).addAll(Seq.empty)

assert(1 == offersCaptor.getValue.size())
assert(offer.getId == offersCaptor.getValue.get(0))
Expand Down
41 changes: 21 additions & 20 deletions src/test/scala/mesosphere/marathon/tasks/TaskQueueTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ import mesosphere.marathon.state.AppDefinition
import mesosphere.marathon.state.PathId.StringPathId
import mesosphere.marathon.tasks.TaskQueue.QueuedTask

import scala.collection.immutable.Seq
import scala.concurrent.duration.Deadline

class TaskQueueTest extends MarathonSpec {
val app1 = AppDefinition(id = "app1".toPath, constraints = Set.empty)
val app2 = AppDefinition(id = "app2".toPath, constraints = Set(buildConstraint("hostname", "UNIQUE"), buildConstraint("rack_id", "CLUSTER", "rack-1")))
Expand Down Expand Up @@ -46,33 +43,37 @@ class TaskQueueTest extends MarathonSpec {
queue.add(app3)

assert(queue.list.size == 3, "Queue should contain 3 elements.")
queue.retain { case QueuedTask(app, _) => app.id == app2.id }
queue.retain { case QueuedTask(app, _, _) => app.id == app2.id }
assert(queue.list.size == 1, "Queue should contain 1 elements.")
}

test("RemoveAll") {
test("pollMatching") {
queue.add(app1)
queue.add(app2)
queue.add(app3)

val res = queue.removeAll().map(_.app)

assert(Vector(app2, app3, app1) == res, s"Should return all elements in correct order.")
assert(queue.queue.isEmpty, "TaskQueue should be empty.")
assert(Some(app1) == queue.pollMatching {
case x if x.id == "app1".toPath => Some(x)
case _ => None
})
}

test("AddAll") {
val queue = new TaskQueue
test("pollMatching Priority") {
queue.add(app1)
queue.add(app2)
queue.add(app3)

queue.addAll(Seq(
QueuedTask(app1, Deadline.now),
QueuedTask(app2, Deadline.now),
QueuedTask(app3, Deadline.now)
))
assert(Some(app2) == queue.pollMatching(Some(_)))
}

assert(queue.list.size == 3, "Queue should contain 3 elements.")
assert(queue.count(app1) == 1, s"Queue should contain $app1.")
assert(queue.count(app2) == 1, s"Queue should contain $app2.")
assert(queue.count(app3) == 1, s"Queue should contain $app3.")
test("pollMatching no match") {
  // queue one task for each of the three fixture apps
  for (app <- Seq(app1, app2, app3)) queue.add(app)

  // the matcher only accepts an app whose id is DOES_NOT_EXIST
  val polled = queue.pollMatching { candidate =>
    if (candidate.id == "DOES_NOT_EXIST".toPath) Some(candidate) else None
  }

  assert(None == polled)
}
}
Loading

0 comments on commit 646e0fd

Please sign in to comment.