diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 760e458972d98..002d7b6eaf498 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -67,6 +67,7 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + private val allocatorLock = new Object() // Fields used in client mode. private var rpcEnv: RpcEnv = null @@ -359,7 +360,9 @@ private[spark] class ApplicationMaster( } logDebug(s"Number of pending allocations is $numPendingAllocate. " + s"Sleeping for $sleepInterval.") - Thread.sleep(sleepInterval) + allocatorLock.synchronized { + allocatorLock.wait(sleepInterval) + } } catch { case e: InterruptedException => } @@ -546,8 +549,15 @@ private[spark] class ApplicationMaster( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { - case Some(a) => a.requestTotalExecutors(requestedTotal) - case None => logWarning("Container allocator is not ready to request executors yet.") + case Some(a) => + allocatorLock.synchronized { + if (a.requestTotalExecutors(requestedTotal)) { + allocatorLock.notifyAll() + } + } + + case None => + logWarning("Container allocator is not ready to request executors yet.") } context.reply(true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 21193e7c625e3..940873fbd046c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -146,11 +146,16 @@ private[yarn] class YarnAllocator( * Request as many executors from the ResourceManager as needed to reach the desired total. If * the requested total is smaller than the current number of running executors, no executors will * be killed. + * + * @return Whether the new requested total is different than the old value. */ - def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { + def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized { if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal + true + } else { + false } }