diff --git a/src/main/scala/mesosphere/marathon/MarathonScheduler.scala b/src/main/scala/mesosphere/marathon/MarathonScheduler.scala index e1785dc1e78..01986e9f921 100644 --- a/src/main/scala/mesosphere/marathon/MarathonScheduler.scala +++ b/src/main/scala/mesosphere/marathon/MarathonScheduler.scala @@ -1,5 +1,6 @@ package mesosphere.marathon +import java.util import javax.inject.{ Inject, Named } import akka.actor.{ ActorRef, ActorSystem } @@ -97,7 +98,7 @@ class MarathonScheduler @Inject() ( Await.result(appRepo.currentAppVersions(), config.zkTimeoutDuration) taskQueue.retain { - case QueuedTask(app, _) => + case QueuedTask(app, _, _) => appVersions.get(app.id) contains app.version } @@ -105,44 +106,26 @@ class MarathonScheduler @Inject() ( try { log.debug("Received offer %s".format(offer)) - val queuedTasks: Seq[QueuedTask] = taskQueue.removeAll() - - val withTaskInfos: collection.Seq[(QueuedTask, (TaskInfo, Seq[Long]))] = - queuedTasks.view.flatMap { case qt => newTask(qt.app, offer).map(qt -> _) } - - val launchedTask = withTaskInfos.find { - case (qt, (taskInfo, ports)) => - val timeLeft = qt.delay.timeLeft - if (timeLeft.toNanos <= 0) { - true - } - else { - log.info(s"Delaying task ${taskInfo.getTaskId.getValue} due to backoff. Time left: $timeLeft.") - false - } + val matchingTask = taskQueue.pollMatching { app => + newTask(app, offer).map(app -> _) } - launchedTask.foreach { - case (qt, (taskInfo, ports)) => - val taskInfos = Seq(taskInfo) - log.debug("Launching tasks: " + taskInfos) - + matchingTask.foreach { + case (app, (taskInfo, ports)) => val marathonTask = MarathonTasks.makeTask( taskInfo.getTaskId.getValue, offer.getHostname, ports, - offer.getAttributesList.asScala, qt.app.version) + offer.getAttributesList.asScala, app.version) + + log.debug("Launching task: " + taskInfo) - taskTracker.created(qt.app.id, marathonTask) - driver.launchTasks(Seq(offer.getId).asJava, taskInfos.asJava) + taskTracker.created(app.id, marathonTask) + driver.launchTasks(Seq(offer.getId).asJava, util.Arrays.asList(taskInfo)) // here it is assumed that the health checks for the current // version are already running. } - // put unscheduled tasks back in the queue - val launchedTaskSeq: Seq[QueuedTask] = launchedTask.map(_._1).to[Seq] - taskQueue.addAll(queuedTasks diff launchedTaskSeq) - - if (launchedTask.isEmpty) { + if (matchingTask.isEmpty) { log.debug("Offer doesn't match request. Declining.") driver.declineOffer(offer.getId) } @@ -289,7 +272,6 @@ class MarathonScheduler @Inject() ( private def newTask( app: AppDefinition, offer: Offer): Option[(TaskInfo, Seq[Long])] = { - // TODO this should return a MarathonTask new TaskBuilder(app, taskIdUtil.newTaskId, taskTracker, config, mapper).buildIfMatches(offer) } } diff --git a/src/main/scala/mesosphere/marathon/MarathonSchedulerActor.scala b/src/main/scala/mesosphere/marathon/MarathonSchedulerActor.scala index cec03fbf568..6ed6c5ef760 100644 --- a/src/main/scala/mesosphere/marathon/MarathonSchedulerActor.scala +++ b/src/main/scala/mesosphere/marathon/MarathonSchedulerActor.scala @@ -483,8 +483,7 @@ class SchedulerActions( if (toQueue > 0) { log.info(s"Queueing $toQueue new tasks for ${app.id} ($queuedCount queued)") - for (i <- 0 until toQueue) - taskQueue.add(app) + taskQueue.add(app, toQueue) } else { log.info(s"Already queued $queuedCount tasks for ${app.id}. Not scaling.") diff --git a/src/main/scala/mesosphere/marathon/tasks/TaskQueue.scala b/src/main/scala/mesosphere/marathon/tasks/TaskQueue.scala index aba37b232e7..06fe266eb21 100644 --- a/src/main/scala/mesosphere/marathon/tasks/TaskQueue.scala +++ b/src/main/scala/mesosphere/marathon/tasks/TaskQueue.scala @@ -1,14 +1,15 @@ package mesosphere.marathon.tasks -import java.util.concurrent.PriorityBlockingQueue +import java.util.concurrent.atomic.AtomicInteger -import mesosphere.marathon.state.{ AppDefinition, PathId } +import mesosphere.marathon.state.{ Timestamp, AppDefinition, PathId } import mesosphere.util.RateLimiter +import org.apache.log4j.Logger +import scala.annotation.tailrec +import scala.collection.concurrent.TrieMap import scala.concurrent.duration.Deadline -import scala.collection.mutable import scala.collection.immutable.Seq -import scala.collection.JavaConverters._ /** * Utility class to stage tasks before they get scheduled @@ -17,21 +18,31 @@ class TaskQueue { import mesosphere.marathon.tasks.TaskQueue._ + private val log = Logger.getLogger(getClass) protected[marathon] val rateLimiter = new RateLimiter - // we used SynchronizedPriorityQueue before, but it has been deprecated - // because it is not safe to use - protected[tasks] var queue = - new PriorityBlockingQueue[QueuedTask](11, AppConstraintsOrdering.reverse) + protected[tasks] var apps = TrieMap.empty[(PathId, Timestamp), QueuedTask] - def list: Seq[QueuedTask] = queue.asScala.to[scala.collection.immutable.Seq] + def list: Seq[QueuedTask] = apps.values.to[Seq] def listApps: Seq[AppDefinition] = list.map(_.app) - def poll(): Option[QueuedTask] = Option(queue.poll()) - - def add(app: AppDefinition): Unit = - queue.add(QueuedTask(app, rateLimiter.getDelay(app))) + def poll(): Option[QueuedTask] = + apps.values.toSeq.sortWith { + case (a, b) => + a.app.constraints.size > b.app.constraints.size + }.find { + case QueuedTask(_, count, _) => count.decrementAndGet() >= 0 + } + + def add(app: AppDefinition): Unit = add(app, 1) + + def add(app: AppDefinition, count: Int): Unit = { + val queuedTask = apps.getOrElseUpdate( + (app.id, app.version), + QueuedTask(app, new AtomicInteger(0), rateLimiter.getDelay(app))) + queuedTask.count.addAndGet(count) + } /** * Number of tasks in the queue for the given app @@ -39,33 +50,56 @@ class TaskQueue { * @param app The app * @return count */ - def count(app: AppDefinition): Int = queue.asScala.count(_.app.id == app.id) + def count(app: AppDefinition): Int = apps.get((app.id, app.version)).map(_.count.get()).getOrElse(0) def purge(appId: PathId): Unit = { - val retained = queue.asScala.filterNot(_.app.id == appId) - removeAll() - queue.addAll(retained.asJavaCollection) + for { + QueuedTask(app, _, _) <- apps.values + if app.id == appId + } apps.remove(app.id -> app.version) } /** * Retains only elements that satisfy the supplied predicate. */ def retain(f: (QueuedTask => Boolean)): Unit = - queue.iterator.asScala.foreach { qt => if (!f(qt)) queue.remove(qt) } - - def addAll(xs: Seq[QueuedTask]): Unit = queue.addAll(xs.asJavaCollection) - - def removeAll(): Seq[QueuedTask] = { - val builder = new java.util.ArrayList[QueuedTask]() - queue.drainTo(builder) - builder.asScala.to[Seq] + apps.values.foreach { + case qt @ QueuedTask(app, _, _) => if (!f(qt)) apps.remove(app.id -> app.version) + } + + def pollMatching[B](f: AppDefinition => Option[B]): Option[B] = { + val sorted = apps.values.toList.sortWith { (a, b) => + a.app.constraints.size > b.app.constraints.size + } + + @tailrec + def findMatching(xs: List[QueuedTask]): Option[B] = xs match { + case Nil => None + case head :: tail => head match { + case QueuedTask(app, _, delay) if delay.hasTimeLeft() => + log.info(s"Delaying ${app.id} due to backoff. Time left: ${delay.timeLeft}.") + findMatching(tail) + + case QueuedTask(app, count, delay) => + val res = f(app) + if (res.isDefined && count.decrementAndGet() >= 0) { + res + } + else { + // app count is 0, so we can remove this app from the queue + apps.remove(app.id -> app.version) + findMatching(tail) + } + } + } + + findMatching(sorted) } - } object TaskQueue { - protected[marathon] case class QueuedTask(app: AppDefinition, delay: Deadline) + protected[marathon] case class QueuedTask(app: AppDefinition, count: AtomicInteger, delay: Deadline) protected object AppConstraintsOrdering extends Ordering[QueuedTask] { def compare(t1: QueuedTask, t2: QueuedTask): Int = diff --git a/src/main/scala/mesosphere/marathon/upgrade/StartingBehavior.scala b/src/main/scala/mesosphere/marathon/upgrade/StartingBehavior.scala index 34db1419ba0..76fb7954a5d 100644 --- a/src/main/scala/mesosphere/marathon/upgrade/StartingBehavior.scala +++ b/src/main/scala/mesosphere/marathon/upgrade/StartingBehavior.scala @@ -75,7 +75,7 @@ trait StartingBehavior { this: Actor with ActorLogging => val actualSize = taskQueue.count(app) + taskTracker.count(app.id) if (actualSize < expectedSize) { - for (_ <- 0 until (expectedSize - actualSize)) taskQueue.add(app) + taskQueue.add(app, expectedSize - actualSize) } context.system.scheduler.scheduleOnce(5.seconds, self, Sync) } diff --git a/src/main/scala/mesosphere/marathon/upgrade/TaskReplaceActor.scala b/src/main/scala/mesosphere/marathon/upgrade/TaskReplaceActor.scala index 98e080f7b9a..b5fea4a58df 100644 --- a/src/main/scala/mesosphere/marathon/upgrade/TaskReplaceActor.scala +++ b/src/main/scala/mesosphere/marathon/upgrade/TaskReplaceActor.scala @@ -44,7 +44,7 @@ class TaskReplaceActor( driver.killTask(taskId) } - for (_ <- 0 until app.instances) taskQueue.add(app) + taskQueue.add(app, app.instances) } override def postStop(): Unit = { diff --git a/src/main/scala/mesosphere/marathon/upgrade/TaskStartActor.scala b/src/main/scala/mesosphere/marathon/upgrade/TaskStartActor.scala index 515a94245b8..6a27c0b2ad9 100644 --- a/src/main/scala/mesosphere/marathon/upgrade/TaskStartActor.scala +++ b/src/main/scala/mesosphere/marathon/upgrade/TaskStartActor.scala @@ -23,7 +23,7 @@ class TaskStartActor( var running: Int = 0 override def initializeStart(): Unit = { - for (_ <- 0 until nrToStart) taskQueue.add(app) + taskQueue.add(app, nrToStart) } override def postStop(): Unit = { diff --git a/src/test/scala/mesosphere/marathon/MarathonSchedulerActorTest.scala b/src/test/scala/mesosphere/marathon/MarathonSchedulerActorTest.scala index f12936bf70e..6d5a83529eb 100644 --- a/src/test/scala/mesosphere/marathon/MarathonSchedulerActorTest.scala +++ b/src/test/scala/mesosphere/marathon/MarathonSchedulerActorTest.scala @@ -144,7 +144,7 @@ class MarathonSchedulerActorTest extends TestKit(ActorSystem("System")) awaitAssert({ verify(tracker).shutdown("nope".toPath) - verify(queue).add(app) + verify(queue).add(app, 1) verify(driver).killTask(TaskID("task_a")) }, 5.seconds, 10.millis) } @@ -168,7 +168,7 @@ class MarathonSchedulerActorTest extends TestKit(ActorSystem("System")) schedulerActor ! ScaleApp("test-app".toPath) awaitAssert({ - verify(queue).add(app) + verify(queue).add(app, 1) }, 5.seconds, 10.millis) expectMsg(5.seconds, AppScaled(app.id)) diff --git a/src/test/scala/mesosphere/marathon/MarathonSchedulerTest.scala b/src/test/scala/mesosphere/marathon/MarathonSchedulerTest.scala index dc8149e70df..54379a3c28f 100644 --- a/src/test/scala/mesosphere/marathon/MarathonSchedulerTest.scala +++ b/src/test/scala/mesosphere/marathon/MarathonSchedulerTest.scala @@ -1,5 +1,7 @@ package mesosphere.marathon +import java.util.concurrent.atomic.AtomicInteger + import akka.actor.ActorSystem import akka.event.EventStream import akka.testkit.{ TestKit, TestProbe } @@ -45,7 +47,7 @@ class MarathonSchedulerTest extends TestKit(ActorSystem("System")) with Marathon repo = mock[AppRepository] hcManager = mock[HealthCheckManager] tracker = mock[TaskTracker] - queue = mock[TaskQueue] + queue = spy(new TaskQueue) frameworkIdUtil = mock[FrameworkIdUtil] config = defaultConfig() taskIdUtil = mock[TaskIdUtil] @@ -81,17 +83,15 @@ class MarathonSchedulerTest extends TestKit(ActorSystem("System")) with Marathon ports = Seq(8080), version = now ) - val queuedTask = QueuedTask(app, Deadline.now) + val queuedTask = QueuedTask(app, new AtomicInteger(app.instances), Deadline.now) val list = Vector(queuedTask) val allApps = Vector(app) + queue.add(app) + when(taskIdUtil.newTaskId("testOffers".toRootPath)) .thenReturn(TaskID.newBuilder.setValue("testOffers_0-1234").build) when(tracker.checkStagedTasks).thenReturn(Seq()) - when(queue.poll()).thenReturn(Some(queuedTask)) - when(queue.list).thenReturn(list) - when(queue.removeAll()).thenReturn(list) - when(queue.listApps).thenReturn(allApps) when(repo.currentAppVersions()) .thenReturn(Future.successful(Map(app.id -> app.version))) @@ -103,7 +103,6 @@ class MarathonSchedulerTest extends TestKit(ActorSystem("System")) with Marathon verify(driver).launchTasks(offersCaptor.capture(), taskInfosCaptor.capture()) verify(tracker).created(same(app.id), marathonTaskCaptor.capture()) - verify(queue).addAll(Seq.empty) assert(1 == offersCaptor.getValue.size()) assert(offer.getId == offersCaptor.getValue.get(0)) diff --git a/src/test/scala/mesosphere/marathon/tasks/TaskQueueTest.scala b/src/test/scala/mesosphere/marathon/tasks/TaskQueueTest.scala index 5954f38bb3c..47b234461a3 100644 --- a/src/test/scala/mesosphere/marathon/tasks/TaskQueueTest.scala +++ b/src/test/scala/mesosphere/marathon/tasks/TaskQueueTest.scala @@ -7,9 +7,6 @@ import mesosphere.marathon.state.AppDefinition import mesosphere.marathon.state.PathId.StringPathId import mesosphere.marathon.tasks.TaskQueue.QueuedTask -import scala.collection.immutable.Seq -import scala.concurrent.duration.Deadline - class TaskQueueTest extends MarathonSpec { val app1 = AppDefinition(id = "app1".toPath, constraints = Set.empty) val app2 = AppDefinition(id = "app2".toPath, constraints = Set(buildConstraint("hostname", "UNIQUE"), buildConstraint("rack_id", "CLUSTER", "rack-1"))) @@ -46,33 +43,37 @@ class TaskQueueTest extends MarathonSpec { queue.add(app3) assert(queue.list.size == 3, "Queue should contain 3 elements.") - queue.retain { case QueuedTask(app, _) => app.id == app2.id } + queue.retain { case QueuedTask(app, _, _) => app.id == app2.id } assert(queue.list.size == 1, "Queue should contain 1 elements.") } - test("RemoveAll") { + test("pollMatching") { queue.add(app1) queue.add(app2) queue.add(app3) - val res = queue.removeAll().map(_.app) - - assert(Vector(app2, app3, app1) == res, s"Should return all elements in correct order.") - assert(queue.queue.isEmpty, "TaskQueue should be empty.") + assert(Some(app1) == queue.pollMatching { + case x if x.id == "app1".toPath => Some(x) + case _ => None + }) } - test("AddAll") { - val queue = new TaskQueue + test("pollMatching Priority") { + queue.add(app1) + queue.add(app2) + queue.add(app3) - queue.addAll(Seq( - QueuedTask(app1, Deadline.now), - QueuedTask(app2, Deadline.now), - QueuedTask(app3, Deadline.now) - )) + assert(Some(app2) == queue.pollMatching(Some(_))) + } - assert(queue.list.size == 3, "Queue should contain 3 elements.") - assert(queue.count(app1) == 1, s"Queue should contain $app1.") - assert(queue.count(app2) == 1, s"Queue should contain $app2.") - assert(queue.count(app3) == 1, s"Queue should contain $app3.") + test("pollMatching no match") { + queue.add(app1) + queue.add(app2) + queue.add(app3) + + assert(None == queue.pollMatching { + case x if x.id == "DOES_NOT_EXIST".toPath => Some(x) + case _ => None + }) } } diff --git a/src/test/scala/mesosphere/marathon/upgrade/DeploymentActorTest.scala b/src/test/scala/mesosphere/marathon/upgrade/DeploymentActorTest.scala index d8caea3d300..d0393835e7d 100644 --- a/src/test/scala/mesosphere/marathon/upgrade/DeploymentActorTest.scala +++ b/src/test/scala/mesosphere/marathon/upgrade/DeploymentActorTest.scala @@ -18,7 +18,7 @@ import mesosphere.mesos.protos.Implicits._ import mesosphere.mesos.protos.TaskID import org.apache.mesos.Protos.Status import org.apache.mesos.SchedulerDriver -import org.mockito.Matchers.any +import org.mockito.Matchers.{ any, same } import org.mockito.Mockito.{ times, verify, when } import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -100,9 +100,11 @@ class DeploymentActorTest } }) - when(queue.add(app2New)).thenAnswer(new Answer[Boolean] { + when(queue.add(same(app2New), any[Int])).thenAnswer(new Answer[Boolean] { def answer(invocation: InvocationOnMock): Boolean = { - system.eventStream.publish(MesosStatusUpdateEvent("", UUID.randomUUID().toString, "TASK_RUNNING", "", app2.id, "", Nil, app2New.version.toString)) + println(invocation.getArguments.toSeq) + for (i <- 0 until invocation.getArguments()(1).asInstanceOf[Int]) + system.eventStream.publish(MesosStatusUpdateEvent("", UUID.randomUUID().toString, "TASK_RUNNING", "", app2.id, "", Nil, app2New.version.toString)) true } }) @@ -194,17 +196,15 @@ class DeploymentActorTest }) val taskIDs = Iterator.from(3) - var taskCount = 0 when(queue.count(appNew)).thenAnswer(new Answer[Int] { - override def answer(p1: InvocationOnMock): Int = taskCount + override def answer(p1: InvocationOnMock): Int = appNew.instances }) - when(queue.add(appNew)).thenAnswer(new Answer[Boolean] { + when(queue.add(same(appNew), any[Int])).thenAnswer(new Answer[Boolean] { def answer(invocation: InvocationOnMock): Boolean = { - if (taskCount >= 2) throw new Exception("Too many invocations.") - taskCount += 1 - system.eventStream.publish(MesosStatusUpdateEvent("", s"task1_${taskIDs.next()}", "TASK_RUNNING", "", app.id, "", Nil, appNew.version.toString)) + for (i <- 0 until invocation.getArguments()(1).asInstanceOf[Int]) + system.eventStream.publish(MesosStatusUpdateEvent("", s"task1_${taskIDs.next()}", "TASK_RUNNING", "", app.id, "", Nil, appNew.version.toString)) true } }) @@ -233,7 +233,7 @@ class DeploymentActorTest verify(driver).killTask(TaskID(task1_1.getId)) verify(driver).killTask(TaskID(task1_2.getId)) - verify(queue, times(2)).add(appNew) + verify(queue).add(appNew, 2) } finally { system.shutdown() diff --git a/src/test/scala/mesosphere/marathon/upgrade/TaskStartActorTest.scala b/src/test/scala/mesosphere/marathon/upgrade/TaskStartActorTest.scala index 7d377a77948..eea4b96140e 100644 --- a/src/test/scala/mesosphere/marathon/upgrade/TaskStartActorTest.scala +++ b/src/test/scala/mesosphere/marathon/upgrade/TaskStartActorTest.scala @@ -3,14 +3,14 @@ package mesosphere.marathon.upgrade import akka.actor.{ ActorSystem, Props } import akka.testkit.{ TestActorRef, TestKit } import com.codahale.metrics.MetricRegistry -import mesosphere.marathon.{ MarathonConf, SchedulerActions, TaskUpgradeCanceledException } import mesosphere.marathon.event.{ HealthStatusChanged, MesosStatusUpdateEvent } import mesosphere.marathon.state.AppDefinition import mesosphere.marathon.state.PathId._ -import mesosphere.marathon.tasks.{ TaskTracker, TaskQueue } +import mesosphere.marathon.tasks.{ TaskQueue, TaskTracker } +import mesosphere.marathon.{ MarathonConf, SchedulerActions, TaskUpgradeCanceledException } import org.apache.mesos.SchedulerDriver import org.apache.mesos.state.InMemoryState -import org.mockito.Mockito.{ spy, times, verify } +import org.mockito.Mockito.{ times, spy, verify } import org.scalatest.mock.MockitoSugar import org.scalatest.{ BeforeAndAfterAll, FunSuiteLike, Matchers } @@ -54,7 +54,7 @@ class TaskStartActorTest awaitCond(taskQueue.count(app) == 5, 3.seconds) - for ((task, i) <- taskQueue.removeAll().zipWithIndex) + for (i <- 0 until taskQueue.count(app)) system.eventStream.publish(MesosStatusUpdateEvent("", s"task-$i", "TASK_RUNNING", "", app.id, "", Nil, app.version.toString)) Await.result(promise.future, 3.seconds) should be(()) @@ -115,8 +115,8 @@ class TaskStartActorTest awaitCond(taskQueue.count(app) == 5, 3.seconds) - for ((_, i) <- taskQueue.removeAll().zipWithIndex) - system.eventStream.publish(HealthStatusChanged(app.id, s"task_${i}", app.version.toString, true)) + for (i <- 0 until taskQueue.count(app)) + system.eventStream.publish(HealthStatusChanged(app.id, s"task_$i", app.version.toString, alive = true)) Await.result(promise.future, 3.seconds) should be(()) @@ -208,14 +208,15 @@ class TaskStartActorTest awaitCond(taskQueue.count(app) == 1, 3.seconds) - for (task <- taskQueue.removeAll()) - system.eventStream.publish(MesosStatusUpdateEvent("", "", "TASK_FAILED", "", app.id, "", Nil, app.version.toString)) + taskQueue.purge(app.id) + + system.eventStream.publish(MesosStatusUpdateEvent("", "", "TASK_FAILED", "", app.id, "", Nil, app.version.toString)) awaitCond(taskQueue.count(app) == 1, 3.seconds) - verify(taskQueue, times(2)).add(app) + verify(taskQueue, times(2)).add(app, 1) - for (task <- taskQueue.removeAll()) + for (i <- 0 until taskQueue.count(app)) system.eventStream.publish(MesosStatusUpdateEvent("", "", "TASK_RUNNING", "", app.id, "", Nil, app.version.toString)) Await.result(promise.future, 3.seconds) should be(())