SPARK-11882: Custom scheduler support #10292

Closed
17 changes: 16 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -21,7 +21,7 @@ import scala.language.implicitConversions

import java.io._
import java.lang.reflect.Constructor
import java.net.URI
import java.net.{URL, URI}
import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger}
import java.util.UUID.randomUUID
@@ -31,6 +31,7 @@ import scala.collection.{Map, Set}
import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
import scala.util.Try
import scala.util.control.NonFatal

import org.apache.commons.lang.SerializationUtils
@@ -2732,6 +2733,20 @@ object SparkContext extends Logging {
scheduler.initialize(backend)
(backend, scheduler)

case uri @ SchedulerFactory(name)
if SchedulerFactory.getSchedulerFactoryClassName(sc.conf.getAll, name).isDefined =>

val className = SchedulerFactory.getSchedulerFactoryClassName(sc.conf.getAll, name).get
val clazz = Utils.classForName(className)
val factory = clazz.newInstance().asInstanceOf[SchedulerFactory]
val scheduler = factory.createScheduler(sc)
val backend = factory.createSchedulerBackend(scheduler, sc, new URI(uri))
scheduler match {
case ts: TaskSchedulerImpl => ts.initialize(backend)
case _ =>
}
(backend, scheduler)

case zkUrl if zkUrl.startsWith("zk://") =>
logWarning("Master URL for a multi-master Mesos cluster managed by ZooKeeper should be " +
"in the form mesos://zk://host:port. Current Master URL will stop working in Spark 2.0.")
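The new case arm above dispatches on the scheme of the master URL: `SchedulerFactory.unapply` extracts the URI scheme, and the scheme is matched against the `<name>` suffix of any `spark.scheduler.factory.<name>` entry in the configuration. A standalone sketch of that resolution (not part of the PR; the configuration value and `com.example` class name are made up):

import java.net.URI

import scala.util.Try

// Standalone sketch of the scheme-to-factory resolution used by the case arm above.
object SchemeResolutionSketch {
  // Entries as they would appear in SparkConf.getAll; the class name is illustrative only.
  val conf: Seq[(String, String)] = Seq(
    "spark.scheduler.factory.custom" -> "com.example.CustomSchedulerFactory")

  private val factoryKey = """^spark\.scheduler\.factory\.(.+)$""".r

  // Mirrors SchedulerFactory.unapply plus getSchedulerFactoryClassName from this PR.
  def factoryFor(masterUrl: String): Option[String] =
    for {
      scheme <- Try(new URI(masterUrl)).toOption.flatMap(u => Option(u.getScheme))
      className <- conf.collectFirst {
        case (factoryKey(name), cls) if name.equalsIgnoreCase(scheme) => cls
      }
    } yield className

  def main(args: Array[String]): Unit = {
    // "custom://1.2.3.4:7077" has scheme "custom", matching spark.scheduler.factory.custom.
    println(factoryFor("custom://1.2.3.4:7077")) // Some(com.example.CustomSchedulerFactory)
    // "spark://1.2.3.4:7077" has scheme "spark", for which no factory is registered, so
    // SparkContext falls through to the built-in case arms.
    println(factoryFor("spark://1.2.3.4:7077"))  // None
  }
}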
47 changes: 36 additions & 11 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -22,6 +22,8 @@ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowab
import java.net.URL
import java.security.PrivilegedExceptionAction

import org.apache.spark.scheduler.SchedulerFactory

import scala.collection.mutable.{ArrayBuffer, HashMap, Map}

import org.apache.commons.lang3.StringUtils
@@ -67,7 +69,8 @@ object SparkSubmit {
private val STANDALONE = 2
private val MESOS = 4
private val LOCAL = 8
private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL
private val CUSTOM = 16
private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | CUSTOM

// Deploy modes
private val CLIENT = 1
@@ -230,7 +233,13 @@ object SparkSubmit {
case m if m.startsWith("spark") => STANDALONE
case m if m.startsWith("mesos") => MESOS
case m if m.startsWith("local") => LOCAL
case _ => printErrorAndExit("Master must start with yarn, spark, mesos, or local"); -1
case m @ SchedulerFactory(name) if SchedulerFactory.getSchedulerFactoryClassName(
args.sparkProperties, name).isDefined =>
childMainClass =
SchedulerFactory.getSchedulerClientClassName(args.sparkProperties, name).getOrElse("")
CUSTOM
case _ => printErrorAndExit("Master must start with yarn, spark, mesos, local or " +
"with a name defined at spark.scheduler.factory.<name> in configuration"); -1
}

// Set the deploy mode; default is client mode
@@ -470,22 +479,22 @@
OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"),

// Other options
OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES,
OptionAssigner(args.executorCores, STANDALONE | CUSTOM | YARN, ALL_DEPLOY_MODES,
sysProp = "spark.executor.cores"),
OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES,
OptionAssigner(args.executorMemory, STANDALONE | CUSTOM | MESOS | YARN, ALL_DEPLOY_MODES,
sysProp = "spark.executor.memory"),
OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES,
OptionAssigner(args.totalExecutorCores, STANDALONE | CUSTOM | MESOS, ALL_DEPLOY_MODES,
sysProp = "spark.cores.max"),
OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES,
OptionAssigner(args.files, LOCAL | STANDALONE | CUSTOM | MESOS, ALL_DEPLOY_MODES,
sysProp = "spark.files"),
OptionAssigner(args.jars, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars"),
OptionAssigner(args.driverMemory, STANDALONE | MESOS, CLUSTER,
OptionAssigner(args.jars, STANDALONE | CUSTOM | MESOS, CLUSTER, sysProp = "spark.jars"),
OptionAssigner(args.driverMemory, STANDALONE | CUSTOM | MESOS, CLUSTER,
sysProp = "spark.driver.memory"),
OptionAssigner(args.driverCores, STANDALONE | MESOS, CLUSTER,
OptionAssigner(args.driverCores, STANDALONE | CUSTOM | MESOS, CLUSTER,
sysProp = "spark.driver.cores"),
OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER,
OptionAssigner(args.supervise.toString, STANDALONE | CUSTOM | MESOS, CLUSTER,
sysProp = "spark.driver.supervise"),
OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy")
OptionAssigner(args.ivyRepoPath, STANDALONE | CUSTOM, CLUSTER, sysProp = "spark.jars.ivy")
)

// In client mode, launch the application main class directly
@@ -607,6 +616,22 @@ object SparkSubmit {
}
}

if (clusterManager == CUSTOM && deployMode == CLUSTER) {
if (childMainClass == "") throw new IllegalArgumentException(
"A custom scheduler is chosen but there is no client class defined for it. " +
"Try defining a client class at spark.scheduler.client.<name> in your configuration.")
Option(args.driverMemory).foreach { m => sysProps += "spark.driver.memory" -> m }
Option(args.driverCores).foreach { c => sysProps += "spark.driver.cores" -> c }
sysProps += "spark.driver.supervise" -> args.supervise.toString
sysProps += "spark.master" -> args.master
childArgs += "launch"
childArgs += args.primaryResource
childArgs += args.mainClass
if (args.childArgs != null) {
childArgs ++= args.childArgs
}
}

// Load any properties specified through --conf and the default properties file
for ((k, v) <- args.sparkProperties) {
sysProps.getOrElseUpdate(k, v)
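For CUSTOM cluster mode, the branch above routes the launch through the class registered under `spark.scheduler.client.<name>`: driver settings travel as system properties, and the launch request is encoded as child arguments ("launch", the primary resource, the user main class, then user arguments). A hypothetical sketch of a client entry point consuming that layout (class and package names are illustrative, not part of the PR):

package com.example.scheduler

// Hypothetical client main class, i.e. the class a user would register at
// spark.scheduler.client.<name>. Only the argument layout matches the PR; the rest is a stub.
object ExampleSchedulerClient {
  def main(args: Array[String]): Unit = args.toList match {
    // SparkSubmit passes: "launch", the primary resource (application jar), the user main
    // class, then any user arguments. Driver memory, cores, the supervise flag and the
    // master URL arrive as system properties (spark.driver.*, spark.master).
    case "launch" :: primaryResource :: mainClass :: userArgs =>
      val master = sys.props.getOrElse("spark.master", "<unset>")
      println(s"Would submit $mainClass from $primaryResource " +
        s"with args [${userArgs.mkString(", ")}] to $master")
      // A real client would ship the jar to the external cluster manager and start the
      // driver there, honouring spark.driver.memory, spark.driver.cores and
      // spark.driver.supervise.
    case other =>
      sys.error(s"Unexpected arguments: ${other.mkString(" ")}")
  }
}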
core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -181,7 +181,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
.getOrElse(sparkProperties.get("spark.executor.instances").orNull)
keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull
principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull

// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && !isR && primaryResource != null) {
val uri = new URI(primaryResource)
@@ -255,7 +254,7 @@ }
}

private def validateKillArguments(): Unit = {
if (!master.startsWith("spark://") && !master.startsWith("mesos://")) {
if (master.startsWith("yarn") || master.startsWith("local")) {
SparkSubmit.printErrorAndExit(
"Killing submissions is only supported in standalone or Mesos mode!")
}
@@ -265,7 +264,7 @@ }
}

private def validateStatusRequestArguments(): Unit = {
if (!master.startsWith("spark://") && !master.startsWith("mesos://")) {
if (master.startsWith("yarn") || master.startsWith("local")) {
SparkSubmit.printErrorAndExit(
"Requesting submission statuses is only supported in standalone or Mesos mode!")
}
core/src/main/scala/org/apache/spark/scheduler/SchedulerFactory.scala
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.scheduler

import java.net.URI

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend

import scala.util.Try

/**
* An interface to implement when plugging in a custom scheduler. The class name of the
* implementation has to be registered in the Spark configuration as
* `spark.scheduler.factory.<name>=<class-name>`, where `<name>` is the master URI scheme
* that makes SparkContext use this scheduler factory.
*/
trait SchedulerFactory {
/**
* Creates the TaskScheduler. Currently it just needs to create an instance of
* [[TaskSchedulerImpl]].
*/
def createScheduler(sc: SparkContext): TaskScheduler

/**
* Creates the custom scheduler backend. The backend must extend
* [[CoarseGrainedSchedulerBackend]].
*/
def createSchedulerBackend(
scheduler: TaskScheduler, sc: SparkContext, uri: URI): CoarseGrainedSchedulerBackend
}

private[spark] object SchedulerFactory {
private val schedulerFactoryPattern = """^spark\.scheduler\.factory\.(.+)$""".r
private val schedulerClientPattern = """^spark\.scheduler\.client\.(.+)$""".r

def getSchedulerFactoryClassName(
conf: Iterable[(String, String)],
schedulerName: String): Option[String] =
conf.collectFirst {
case (schedulerFactoryPattern(name), clazzName) if name.equalsIgnoreCase(schedulerName) =>
clazzName
}

def getSchedulerClientClassName(
conf: Iterable[(String, String)],
schedulerName: String): Option[String] =
conf.collectFirst {
case (schedulerClientPattern(name), clazzName) if name.equalsIgnoreCase(schedulerName) =>
clazzName
}

def unapply(masterUri: String): Option[String] = {
for (uri <- Try(new URI(masterUri)).toOption) yield uri.getScheme
}

}
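To tie the pieces together, here is a minimal driver-side wiring sketch for the registration described in the scaladoc above. `com.example.ExampleSchedulerFactory` is assumed to be a user-supplied implementation of the `SchedulerFactory` trait available on the driver classpath, and the `example://` master URL is made up:

import org.apache.spark.{SparkConf, SparkContext}

// Minimal wiring sketch, assuming com.example.ExampleSchedulerFactory implements the
// SchedulerFactory trait above and is on the driver classpath.
object CustomSchedulerWiring {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("custom-scheduler-demo")
      // The key suffix ("example") must match the scheme of the master URL below.
      .set("spark.scheduler.factory.example", "com.example.ExampleSchedulerFactory")
      // SparkContext sees the "example" scheme, looks up the factory entry above and asks
      // the factory for the TaskScheduler and the CoarseGrainedSchedulerBackend.
      .setMaster("example://1.2.3.4:7077")

    val sc = new SparkContext(conf)
    try {
      sc.parallelize(1 to 100).count() // executed through the custom backend
    } finally {
      sc.stop()
    }
  }
}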
core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -17,14 +17,33 @@

package org.apache.spark

import java.net.URI

import org.apache.spark.rpc.RpcEnv
import org.scalatest.PrivateMethodTester

import org.apache.spark.util.Utils
import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend}
import org.apache.spark.scheduler.{SchedulerFactory, SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SimrSchedulerBackend,
SparkDeploySchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.scheduler.local.LocalBackend

class CustomSchedulerFactory extends SchedulerFactory {
override def createScheduler(sc: SparkContext): TaskScheduler = {
new TaskSchedulerImpl(sc)
}

override def createSchedulerBackend(
scheduler: TaskScheduler, sc: SparkContext, uri: URI): CoarseGrainedSchedulerBackend = {
new CustomScheduler(scheduler.asInstanceOf[TaskSchedulerImpl], sc, Array(uri.toString))
}
}

class CustomScheduler(scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String])
extends CoarseGrainedSchedulerBackend(scheduler, RpcEnv.create("x", "localhost", 0,
new SparkConf(), new SecurityManager(new SparkConf())))

class SparkContextSchedulerCreationSuite
extends SparkFunSuite with LocalSparkContext with PrivateMethodTester with Logging {

@@ -48,6 +67,16 @@ class SparkContextSchedulerCreationSuite
assert(e.getMessage.contains("Could not parse Master URL"))
}

test("custom") {
val sched = createTaskScheduler("custom://1.2.3.4:100", new SparkConf()
.set("spark.scheduler.factory.custom", classOf[CustomSchedulerFactory].getCanonicalName))
sched.backend match {
case s: CustomScheduler =>
s.rpcEnv.shutdown()
case _ => fail()
}
}

test("local") {
val sched = createTaskScheduler("local")
sched.backend match {
56 changes: 56 additions & 0 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -261,6 +261,40 @@ class SparkSubmitSuite
sysProps("spark.ui.enabled") should be ("false")
}

test("handles custom cluster mode") {
val clArgs = Seq(
"--deploy-mode", "cluster",
"--master", "custom://h:p",
"--class", "org.SomeClass",
"--supervise",
"--driver-memory", "4g",
"--driver-cores", "5",
"--conf", "spark.ui.enabled=false",
"--conf", "spark.scheduler.factory.custom=some.custom.SchedulerFactory",
"--conf", "spark.scheduler.client.custom=some.custom.SchedulerClient",
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
val childArgsStr = childArgs.mkString(" ")
childArgsStr should include regex "launch .*thejar.jar org.SomeClass arg1 arg2"
mainClass should be ("some.custom.SchedulerClient")
classpath should have size 0
sysProps should have size 11
sysProps.keys should contain ("SPARK_SUBMIT")
sysProps.keys should contain ("spark.master")
sysProps.keys should contain ("spark.app.name")
sysProps.keys should contain ("spark.jars")
sysProps.keys should contain ("spark.driver.memory")
sysProps.keys should contain ("spark.driver.cores")
sysProps.keys should contain ("spark.driver.supervise")
sysProps.keys should contain ("spark.ui.enabled")
sysProps.keys should contain ("spark.submit.deployMode")
sysProps.keys should contain ("spark.scheduler.factory.custom")
sysProps.keys should contain ("spark.scheduler.client.custom")
sysProps("spark.ui.enabled") should be ("false")
}

test("handles standalone client mode") {
val clArgs = Seq(
"--deploy-mode", "client",
@@ -283,6 +317,28 @@ class SparkSubmitSuite
sysProps("spark.ui.enabled") should be ("false")
}

test("handles custom client mode") {
val clArgs = Seq(
"--deploy-mode", "client",
"--master", "custom://h:p",
"--executor-memory", "5g",
"--total-executor-cores", "5",
"--class", "org.SomeClass",
"--driver-memory", "4g",
"--conf", "spark.ui.enabled=false",
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
childArgs.mkString(" ") should be ("arg1 arg2")
mainClass should be ("org.SomeClass")
classpath should have length (1)
classpath(0) should endWith ("thejar.jar")
sysProps("spark.executor.memory") should be ("5g")
sysProps("spark.cores.max") should be ("5")
sysProps("spark.ui.enabled") should be ("false")
}

test("handles mesos client mode") {
val clArgs = Seq(
"--deploy-mode", "client",