Skip to content

Commit

Permalink
[SPARK-17874][CORE] Add SSL port configuration.
Browse files Browse the repository at this point in the history
Make the SSL port configuration explicit, instead of deriving it
from the non-SSL port, but retain the existing functionality in
case anyone depends on it.

The change starts the HTTPS and HTTP connectors separately, so
that it's possible to use independent ports for each. For that to
work, the initialization of the server needs to be shuffled around
a bit. The change also makes the initialization of both connectors
similar, so that they end up using the same Scheduler - previously
only the HTTP connector would use the correct one.

Also fixed some outdated documentation about a couple of services
that were removed long ago.

Tested with unit tests and by running spark-shell with SSL configs.

Author: Marcelo Vanzin <[email protected]>

Closes #16625 from vanzin/SPARK-17874.
  • Loading branch information
Marcelo Vanzin authored and sarutak committed Feb 9, 2017
1 parent 1a09cd6 commit 3fc8e8c
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 89 deletions.
9 changes: 9 additions & 0 deletions core/src/main/scala/org/apache/spark/SSLOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import org.apache.spark.internal.Logging
*
* @param enabled enables or disables SSL; if it is set to false, the rest of the
* settings are disregarded
* @param port the port to bind the SSL server to; if not defined, it will be
* based on the non-SSL port for the same service.
* @param keyStore a path to the key-store file
* @param keyStorePassword a password to access the key-store file
* @param keyPassword a password to access the private key in the key-store
Expand All @@ -47,6 +49,7 @@ import org.apache.spark.internal.Logging
*/
private[spark] case class SSLOptions(
enabled: Boolean = false,
port: Option[Int] = None,
keyStore: Option[File] = None,
keyStorePassword: Option[String] = None,
keyPassword: Option[String] = None,
Expand Down Expand Up @@ -164,6 +167,11 @@ private[spark] object SSLOptions extends Logging {
def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = {
val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled))

val port = conf.getOption(s"$ns.port").map(_.toInt)
port.foreach { p =>
require(p >= 0, "Port number must be a non-negative value.")
}

val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_))
.orElse(defaults.flatMap(_.keyStore))

Expand Down Expand Up @@ -198,6 +206,7 @@ private[spark] object SSLOptions extends Logging {

new SSLOptions(
enabled,
port,
keyStore,
keyStorePassword,
keyPassword,
Expand Down
187 changes: 102 additions & 85 deletions core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.xml.Node

import org.eclipse.jetty.client.api.Response
import org.eclipse.jetty.proxy.ProxyServlet
import org.eclipse.jetty.server.{HttpConnectionFactory, Request, Server, ServerConnector}
import org.eclipse.jetty.server._
import org.eclipse.jetty.server.handler._
import org.eclipse.jetty.servlet._
import org.eclipse.jetty.servlets.gzip.GzipHandler
Expand Down Expand Up @@ -279,109 +279,125 @@ private[spark] object JettyUtils extends Logging {

addFilters(handlers, conf)

val gzipHandlers = handlers.map { h =>
h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME))

val gzipHandler = new GzipHandler
gzipHandler.setHandler(h)
gzipHandler
// Start the server first, with no connectors.
val pool = new QueuedThreadPool
if (serverName.nonEmpty) {
pool.setName(serverName)
}
pool.setDaemon(true)

// Bind to the given port, or throw a java.net.BindException if the port is occupied
def connect(currentPort: Int): ((Server, Option[Int]), Int) = {
val pool = new QueuedThreadPool
if (serverName.nonEmpty) {
pool.setName(serverName)
}
pool.setDaemon(true)

val server = new Server(pool)
val connectors = new ArrayBuffer[ServerConnector]()
val collection = new ContextHandlerCollection

// Create a connector on port currentPort to listen for HTTP requests
val httpConnector = new ServerConnector(
server,
null,
// Call this full constructor to set this, which forces daemon threads:
new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true),
null,
-1,
-1,
new HttpConnectionFactory())
httpConnector.setPort(currentPort)
connectors += httpConnector

val httpsConnector = sslOptions.createJettySslContextFactory() match {
case Some(factory) =>
// If the new port wraps around, do not try a privileged port.
val securePort =
if (currentPort != 0) {
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
} else {
0
}
val scheme = "https"
// Create a connector on port securePort to listen for HTTPS requests
val connector = new ServerConnector(server, factory)
connector.setPort(securePort)
connector.setName(SPARK_CONNECTOR_NAME)
connectors += connector

// redirect the HTTP requests to HTTPS port
httpConnector.setName(REDIRECT_CONNECTOR_NAME)
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
Some(connector)
val server = new Server(pool)

case None =>
// No SSL, so the HTTP connector becomes the official one where all contexts bind.
httpConnector.setName(SPARK_CONNECTOR_NAME)
None
}
val errorHandler = new ErrorHandler()
errorHandler.setShowStacks(true)
errorHandler.setServer(server)
server.addBean(errorHandler)

val collection = new ContextHandlerCollection
server.setHandler(collection)

// Executor used to create daemon threads for the Jetty connectors.
val serverExecutor = new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true)

try {
server.start()

// As each acceptor and each selector will use one thread, the number of threads should at
// least be the number of acceptors and selectors plus 1. (See SPARK-13776)
var minThreads = 1
connectors.foreach { connector =>

def newConnector(
connectionFactories: Array[ConnectionFactory],
port: Int): (ServerConnector, Int) = {
val connector = new ServerConnector(
server,
null,
serverExecutor,
null,
-1,
-1,
connectionFactories: _*)
connector.setPort(port)
connector.start()

// Currently we only use "SelectChannelConnector"
// Limit the max acceptor number to 8 so that we don't waste a lot of threads
connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8))
connector.setHost(hostName)
// The number of selectors always equals to the number of acceptors
minThreads += connector.getAcceptors * 2

(connector, connector.getLocalPort())
}
pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))

val errorHandler = new ErrorHandler()
errorHandler.setShowStacks(true)
errorHandler.setServer(server)
server.addBean(errorHandler)

gzipHandlers.foreach(collection.addHandler)
server.setHandler(collection)

server.setConnectors(connectors.toArray)
try {
server.start()
((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort)
} catch {
case e: Exception =>
server.stop()
pool.stop()
throw e
// If SSL is configured, create the secure connector first.
val securePort = sslOptions.createJettySslContextFactory().map { factory =>
val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0)
val secureServerName = if (serverName.nonEmpty) s"$serverName (HTTPS)" else serverName
val connectionFactories = AbstractConnectionFactory.getFactories(factory,
new HttpConnectionFactory())

def sslConnect(currentPort: Int): (ServerConnector, Int) = {
newConnector(connectionFactories, currentPort)
}

val (connector, boundPort) = Utils.startServiceOnPort[ServerConnector](securePort,
sslConnect, conf, secureServerName)
connector.setName(SPARK_CONNECTOR_NAME)
server.addConnector(connector)
boundPort
}
}

val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf,
serverName)
ServerInfo(server, boundPort, securePort,
server.getHandler().asInstanceOf[ContextHandlerCollection])
// Bind the HTTP port.
def httpConnect(currentPort: Int): (ServerConnector, Int) = {
newConnector(Array(new HttpConnectionFactory()), currentPort)
}

val (httpConnector, httpPort) = Utils.startServiceOnPort[ServerConnector](port, httpConnect,
conf, serverName)

// If SSL is configured, then configure redirection in the HTTP connector.
securePort match {
case Some(p) =>
httpConnector.setName(REDIRECT_CONNECTOR_NAME)
val redirector = createRedirectHttpsHandler(p, "https")
collection.addHandler(redirector)
redirector.start()

case None =>
httpConnector.setName(SPARK_CONNECTOR_NAME)
}

server.addConnector(httpConnector)

// Add all the known handlers now that connectors are configured.
handlers.foreach { h =>
h.setVirtualHosts(toVirtualHosts(SPARK_CONNECTOR_NAME))
val gzipHandler = new GzipHandler()
gzipHandler.setHandler(h)
collection.addHandler(gzipHandler)
gzipHandler.start()
}

pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))
ServerInfo(server, httpPort, securePort, collection)
} catch {
case e: Exception =>
server.stop()
if (serverExecutor.isStarted()) {
serverExecutor.stop()
}
if (pool.isStarted()) {
pool.stop()
}
throw e
}
}

private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = {
val redirectHandler: ContextHandler = new ContextHandler
redirectHandler.setContextPath("/")
redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME))
redirectHandler.setVirtualHosts(toVirtualHosts(REDIRECT_CONNECTOR_NAME))
redirectHandler.setHandler(new AbstractHandler {
override def handle(
target: String,
Expand All @@ -394,8 +410,7 @@ private[spark] object JettyUtils extends Logging {
val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort,
baseRequest.getRequestURI, baseRequest.getQueryString)
response.setContentLength(0)
response.encodeRedirectURL(httpsURI)
response.sendRedirect(httpsURI)
response.sendRedirect(response.encodeRedirectURL(httpsURI))
baseRequest.setHandled(true)
}
})
Expand Down Expand Up @@ -456,6 +471,8 @@ private[spark] object JettyUtils extends Logging {
new URI(scheme, authority, path, query, null).toString
}

/**
 * Builds the virtual host entries for the given connector names by prefixing each
 * name with "@". The result keeps the order of the input and is empty when no
 * names are given.
 */
def toVirtualHosts(connectors: String*): Array[String] = {
  connectors.map(name => s"@$name").toArray
}

}

private[spark] case class ServerInfo(
Expand All @@ -465,7 +482,7 @@ private[spark] case class ServerInfo(
private val rootHandler: ContextHandlerCollection) {

def addHandler(handler: ContextHandler): Unit = {
handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME))
handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME))
rootHandler.addHandler(handler)
if (!handler.isStarted()) {
handler.start()
Expand Down
11 changes: 9 additions & 2 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,14 @@ private[spark] object Utils extends Logging {
}
}

/**
 * Computes the port to try when binding a service, given a base port and an offset.
 *
 * The sum `base + offset` is mapped into the non-privileged range [1024, 65535]:
 * a sum past 65535 wraps around and continues from 1024 rather than from 0.
 * Sums in the range [0, 1024) are returned unchanged.
 *
 * @param base the port the caller started from
 * @param offset how far past `base` to go
 * @return the port to try
 */
def userPort(base: Int, offset: Int): Int = {
  val firstUserPort = 1024
  val userPortCount = 65536 - firstUserPort
  val shifted = base + offset - firstUserPort
  shifted % userPortCount + firstUserPort
}

/**
* Attempt to start a service on the given port, or fail after a number of attempts.
* Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0).
Expand Down Expand Up @@ -2229,8 +2237,7 @@ private[spark] object Utils extends Logging {
val tryPort = if (startPort == 0) {
startPort
} else {
// If the new port wraps around, do not try a privilege port
((startPort + offset - 1024) % (65536 - 1024)) + 1024
userPort(startPort, offset)
}
try {
val (service, port) = startService(tryPort)
Expand Down
2 changes: 2 additions & 0 deletions core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
val conf = new SparkConf
conf.set("spark.ssl.enabled", "true")
conf.set("spark.ssl.ui.enabled", "false")
conf.set("spark.ssl.ui.port", "4242")
conf.set("spark.ssl.keyStore", keyStorePath)
conf.set("spark.ssl.keyStorePassword", "password")
conf.set("spark.ssl.ui.keyStorePassword", "12345")
Expand All @@ -118,6 +119,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts))

assert(opts.enabled === false)
assert(opts.port === Some(4242))
assert(opts.trustStore.isDefined === true)
assert(opts.trustStore.get.getName === "truststore")
assert(opts.trustStore.get.getAbsolutePath === trustStorePath)
Expand Down
28 changes: 27 additions & 1 deletion core/src/test/scala/org/apache/spark/ui/UISuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.LocalSparkContext._
import org.apache.spark.util.Utils

class UISuite extends SparkFunSuite {

Expand All @@ -52,13 +53,16 @@ class UISuite extends SparkFunSuite {
(conf, new SecurityManager(conf).getSSLOptions("ui"))
}

private def sslEnabledConf(): (SparkConf, SSLOptions) = {
/**
 * Creates a SparkConf with SSL enabled for the "ui" namespace, using the test keystore,
 * and returns it together with the SSL options that SecurityManager derives from it.
 *
 * @param sslPort when defined, also sets "spark.ssl.ui.port" to this value
 */
private def sslEnabledConf(sslPort: Option[Int] = None): (SparkConf, SSLOptions) = {
  val keyStoreLocation = getTestResourcePath("spark.keystore")
  val sslConf = new SparkConf()
  sslConf.set("spark.ssl.ui.enabled", "true")
  sslConf.set("spark.ssl.ui.keyStore", keyStoreLocation)
  sslConf.set("spark.ssl.ui.keyStorePassword", "123456")
  sslConf.set("spark.ssl.ui.keyPassword", "123456")
  // Only configure an explicit SSL port when the caller asked for one.
  for (explicitPort <- sslPort) {
    sslConf.set("spark.ssl.ui.port", explicitPort.toString)
  }
  val uiSslOptions = new SecurityManager(sslConf).getSSLOptions("ui")
  (sslConf, uiSslOptions)
}

Expand Down Expand Up @@ -275,6 +279,28 @@ class UISuite extends SparkFunSuite {
}
}

// Verifies that an explicitly configured SSL port is honored instead of the default
// "http port + 400" derivation.
test("specify both http and https ports separately") {
  var socket: ServerSocket = null
  var serverInfo: ServerInfo = null
  try {
    // Occupy an ephemeral port so the ports used below are derived from a real one.
    socket = new ServerSocket(0)

    // Make sure the SSL port lies way outside the "http + 400" range used as the default.
    val baseSslPort = Utils.userPort(socket.getLocalPort(), 10000)
    val (conf, sslOptions) = sslEnabledConf(sslPort = Some(baseSslPort))

    serverInfo = JettyUtils.startJettyServer("0.0.0.0", socket.getLocalPort() + 1,
      sslOptions, Seq[ServletContextHandler](), conf, "server1")

    // The port the HTTPS connector would have used had it been derived from the HTTP port.
    val notAllowed = Utils.userPort(serverInfo.boundPort, 400)
    assert(serverInfo.securePort.isDefined)
    assert(serverInfo.securePort.get != notAllowed)
  } finally {
    stopServer(serverInfo)
    closeSocket(socket)
  }
}

/** Stops the given server; a null reference (server never started) is ignored. */
def stopServer(info: ServerInfo): Unit = {
  Option(info).foreach(_.stop())
}
Expand Down
14 changes: 14 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,20 @@ Apart from these, the following properties are also available, and may be useful
Configuration</a> for details on hierarchical SSL configuration for services.
</td>
</tr>
<tr>
<td><code>spark.ssl.[namespace].port</code></td>
<td>None</td>
<td>
The port on which the SSL service will listen.

<br />The port must be defined within a namespace configuration; see
<a href="security.html#ssl-configuration">SSL Configuration</a> for the available
namespaces.

<br />When not set, the SSL port will be derived from the non-SSL port for the
same service. A value of "0" will make the service bind to an ephemeral port.
</td>
</tr>
<tr>
<td><code>spark.ssl.enabledAlgorithms</code></td>
<td>Empty</td>
Expand Down
Loading

0 comments on commit 3fc8e8c

Please sign in to comment.