Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-20393][Webu UI] Strengthen Spark to prevent XSS vulnerabilities #17686

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {

def render(request: HttpServletRequest): Seq[Node] = {
// stripXSS is called first to remove suspicious characters used in XSS attacks
val requestedIncomplete =
Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean
Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean

val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete)
val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")

/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val appId = UIUtils.stripXSS(request.getParameter("appId"))
val state = master.askSync[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId)
.getOrElse(state.completedApps.find(_.id == appId).orNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
if (parent.killEnabled &&
parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
val id = Option(request.getParameter("id"))
// stripXSS is called first to remove suspicious characters used in XSS attacks
val killFlag = Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean
val id = Option(UIUtils.stripXSS(request.getParameter("id")))
if (id.isDefined && killFlag) {
action(id.get)
}
Expand Down
26 changes: 14 additions & 12 deletions core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
private val supportedLogTypes = Set("stderr", "stdout")
private val defaultBytes = 100 * 1024

// stripXSS is called first to remove suspicious characters used in XSS attacks
def renderLog(request: HttpServletRequest): String = {
val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength = Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt).getOrElse(defaultBytes)

val logDir = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand All @@ -55,13 +56,14 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
pre + logText
}

// stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength = Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt).getOrElse(defaultBytes)

val (logDir, params, pageName) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand Down
22 changes: 22 additions & 0 deletions core/src/main/scala/org/apache/spark/ui/UIUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import scala.xml.transform.{RewriteRule, RuleTransformer}

import org.apache.spark.internal.Logging
import org.apache.spark.ui.scope.RDDOperationGraph
import org.apache.commons.lang3.StringEscapeUtils

/** Utility functions for generating XML pages with spark content. */
private[spark] object UIUtils extends Logging {
Expand Down Expand Up @@ -527,4 +528,25 @@ private[spark] object UIUtils extends Logging {
origHref
}
}

/**
* Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks
*
* For more information about XSS testing:
* https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and
* https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001)
*/
def stripXSS(requestParameter: String): String = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you would, add a couple brief tests of this to UIUtilsSuite

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll work on some.

var requestParameterStripped = requestParameter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the body is simpler as

if (requestParameter == null) {
  null
} else {
  StringEscapeUtils.escapeHtml4(
    requestParameter.replaceAll("(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)", ""))
}

I'm not sure if it will matter for performance but you could pull out the regex as a saved instance variable instead of compiling it each time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Followed your suggestion.

if (requestParameterStripped != null) {
// Avoid null characters or single quote
requestParameterStripped = requestParameterStripped.replaceAll("(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)", "")
requestParameterStripped = StringEscapeUtils.escapeHtml4(requestParameterStripped)
}
requestParameterStripped
}

def stripXSSArray(requestParameter: Array[String]): Array[String] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm on the fence about whether it's worth including this as you can just write _.map(stripXSS) in the two places it's used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll check this again; I had a problem with _.map(stripXSS) compiling; but maybe I did something wrong with it.

And in regards to all the calls; I checked for any type of getParameter call; and also looked for HttpServletRequest with other types of calls. Doesn't appear to be any I missed.

requestParameter.map(stripXSS)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage

private val sc = parent.sc

// stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val executorId = Option(request.getParameter("executorId")).map { executorId =>
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId =>
UIUtils.decodeURLParameter(executorId)
}.getOrElse {
throw new IllegalArgumentException(s"Missing executorId parameter")
Expand Down
14 changes: 8 additions & 6 deletions core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,20 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
jobTag: String,
jobs: Seq[JobUIData],
killEnabled: Boolean): Seq[Node] = {
val allParameters = request.getParameterMap.asScala.toMap
// stripXSS is called to remove suspicious characters used in XSS attacks
val allParameters = request.getParameterMap.asScala.toMap.mapValues(UIUtils.stripXSSArray(_))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: omit (_)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag))
.map(para => para._1 + "=" + para._2(0))

val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"

val parameterJobPage = request.getParameter(jobTag + ".page")
val parameterJobSortColumn = request.getParameter(jobTag + ".sort")
val parameterJobSortDesc = request.getParameter(jobTag + ".desc")
val parameterJobPageSize = request.getParameter(jobTag + ".pageSize")
val parameterJobPrevPageSize = request.getParameter(jobTag + ".prevPageSize")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page"))
val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort"))
val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc"))
val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize"))
val parameterJobPrevPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".prevPageSize"))

val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1)
val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn =>
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val listener = parent.jobProgresslistener

listener.synchronized {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val jobId = parameterId.toInt
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}

/** Web UI showing progress status of all jobs in the given SparkContext. */
private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
Expand All @@ -40,7 +40,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {

def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val jobId = Option(request.getParameter("id")).map(_.toInt)
// stripXSS is called first to remove suspicious characters used in XSS attacks
val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
jobId.foreach { id =>
if (jobProgresslistener.activeJobs.contains(id)) {
sc.foreach(_.cancelJob(id))
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {

def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
val poolName = Option(request.getParameter("poolname")).map { poolname =>
// stripXSS is called first to remove suspicious characters used in XSS attacks
val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname =>
UIUtils.decodeURLParameter(poolname)
}.getOrElse {
throw new IllegalArgumentException(s"Missing poolname parameter")
Expand Down
15 changes: 8 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,18 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {

def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterAttempt = request.getParameter("attempt")
val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt"))
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")

val parameterTaskPage = request.getParameter("task.page")
val parameterTaskSortColumn = request.getParameter("task.sort")
val parameterTaskSortDesc = request.getParameter("task.desc")
val parameterTaskPageSize = request.getParameter("task.pageSize")
val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize")
val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page"))
val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort"))
val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc"))
val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize"))
val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize"))

val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn =>
Expand Down
14 changes: 7 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ private[ui] class StageTableBase(
isFairScheduler: Boolean,
killEnabled: Boolean,
isFailedStage: Boolean) {
val allParameters = request.getParameterMap().asScala.toMap
// stripXSS is called to remove suspicious characters used in XSS attacks
val allParameters = request.getParameterMap.asScala.toMap.mapValues(UIUtils.stripXSSArray(_))
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag))
.map(para => para._1 + "=" + para._2(0))

val parameterStagePage = request.getParameter(stageTag + ".page")
val parameterStageSortColumn = request.getParameter(stageTag + ".sort")
val parameterStageSortDesc = request.getParameter(stageTag + ".desc")
val parameterStagePageSize = request.getParameter(stageTag + ".pageSize")
val parameterStagePrevPageSize = request.getParameter(stageTag + ".prevPageSize")
val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page"))
val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort"))
val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc"))
val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize"))
val parameterStagePrevPageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize"))

val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1)
val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn =>
Expand Down Expand Up @@ -512,4 +513,3 @@ private[ui] class StageDataSource(
}
}
}

5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}

/** Web UI showing progress status of all stages in the given SparkContext. */
private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
Expand All @@ -39,7 +39,8 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"

def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val stageId = Option(request.getParameter("id")).map(_.toInt)
// stripXSS is called first to remove suspicious characters used in XSS attacks
val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
stageId.foreach { id =>
if (progressListener.activeStages.contains(id)) {
sc.foreach(_.cancelStage(id, "killed via the Web UI"))
Expand Down
13 changes: 7 additions & 6 deletions core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
private val listener = parent.listener

def render(request: HttpServletRequest): Seq[Node] = {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterBlockPage = request.getParameter("block.page")
val parameterBlockSortColumn = request.getParameter("block.sort")
val parameterBlockSortDesc = request.getParameter("block.desc")
val parameterBlockPageSize = request.getParameter("block.pageSize")
val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize")
val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page"))
val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort"))
val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc"))
val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize"))
val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize"))

val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1)
val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") {

override def render(request: HttpServletRequest): Seq[Node] = {
val driverId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val driverId = UIUtils.stripXSS(request.getParameter("id"))
require(driverId != null && driverId.nonEmpty, "Missing id parameter")

val state = parent.scheduler.getDriverState(driverId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
private val listener = parent.listener

override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized {
val parameterExecutionId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterExecutionId != null && parameterExecutionId.nonEmpty,
"Missing execution id parameter")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)

/** Render the page */
def render(request: HttpServletRequest): Seq[Node] = {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val content =
Expand Down Expand Up @@ -197,4 +198,3 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
}

def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized {
val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse {
// stripXSS is called first to remove suspicious characters used in XSS attacks
val batchTime = Option(SparkUIUtils.stripXSS(request.getParameter("id"))).map(id => Time(id.toLong)).getOrElse {
throw new IllegalArgumentException(s"Missing id parameter")
}
val formattedBatchTime =
Expand Down