Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-20393][Webu UI] Strengthen Spark to prevent XSS vulnerabilities #17686

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {

def render(request: HttpServletRequest): Seq[Node] = {
//stripXSS is called first to remove suspicious characters used in XSS attacks
val requestedIncomplete =
Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean
Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean

val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete)
val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")

/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
//stripXSS is called first to remove suspicious characters used in XSS attacks
val appId = UIUtils.stripXSS(request.getParameter("appId"))
val state = master.askSync[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId)
.getOrElse(state.completedApps.find(_.id == appId).orNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
if (parent.killEnabled &&
parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
val id = Option(request.getParameter("id"))
//stripXSS is called first to remove suspicious characters used in XSS attacks
val killFlag = Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean
val id = Option(UIUtils.stripXSS(request.getParameter("id")))
if (id.isDefined && killFlag) {
action(id.get)
}
Expand Down
26 changes: 14 additions & 12 deletions core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
private val supportedLogTypes = Set("stderr", "stdout")
private val defaultBytes = 100 * 1024

//stripXSS is called first to remove suspicious characters used in XSS attacks
def renderLog(request: HttpServletRequest): String = {
val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength = Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt).getOrElse(defaultBytes)

val logDir = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand All @@ -55,13 +56,14 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
pre + logText
}

//stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength = Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt).getOrElse(defaultBytes)

val (logDir, params, pageName) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/scala/org/apache/spark/ui/UIUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import scala.xml.transform.{RewriteRule, RuleTransformer}

import org.apache.spark.internal.Logging
import org.apache.spark.ui.scope.RDDOperationGraph
import org.apache.commons.lang3.StringEscapeUtils

/** Utility functions for generating XML pages with spark content. */
private[spark] object UIUtils extends Logging {
Expand Down Expand Up @@ -527,4 +528,27 @@ private[spark] object UIUtils extends Logging {
origHref
}
}

/**
* Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks
*
* For more information about XSS testing:
* https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and
* https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001)
*/
def stripXSS(url: String): String = {
var strippedXSSUrl = url
if (strippedXSSUrl != null) {
// Avoid null characters or single quote
strippedXSSUrl = strippedXSSUrl.replaceAll("(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)", "")
strippedXSSUrl = StringEscapeUtils.escapeHtml4(strippedXSSUrl)
}
strippedXSSUrl
}

def stripXSSMap(url: Array[String]): Array[String] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what this method is trying to do but it just returns its argument. Do you just mean url.map(stripXSS)? why Map, why url as names?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I appreciate the feedback. The method returned the array of strings stripped of possible XSS issues. I have used your recommendation in the next commit. Map was changed to Array. url is now requestParameter.

var strippedXSSUrl = url
strippedXSSUrl.foreach(stripXSS(_))
strippedXSSUrl
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage

private val sc = parent.sc

//stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val executorId = Option(request.getParameter("executorId")).map { executorId =>
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId =>
UIUtils.decodeURLParameter(executorId)
}.getOrElse {
throw new IllegalArgumentException(s"Missing executorId parameter")
Expand Down
14 changes: 8 additions & 6 deletions core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,20 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
jobTag: String,
jobs: Seq[JobUIData],
killEnabled: Boolean): Seq[Node] = {
val allParameters = request.getParameterMap.asScala.toMap
//stripXSS is called to remove suspicious characters used in XSS attacks
val allParameters = request.getParameterMap.asScala.toMap.mapValues(UIUtils.stripXSSMap(_))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(_) is redundant here

val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag))
.map(para => para._1 + "=" + para._2(0))

val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"

val parameterJobPage = request.getParameter(jobTag + ".page")
val parameterJobSortColumn = request.getParameter(jobTag + ".sort")
val parameterJobSortDesc = request.getParameter(jobTag + ".desc")
val parameterJobPageSize = request.getParameter(jobTag + ".pageSize")
val parameterJobPrevPageSize = request.getParameter(jobTag + ".prevPageSize")
//stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page"))
val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort"))
val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc"))
val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize"))
val parameterJobPrevPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".prevPageSize"))

val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1)
val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn =>
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val listener = parent.jobProgresslistener

listener.synchronized {
val parameterId = request.getParameter("id")
//stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val jobId = parameterId.toInt
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}

/** Web UI showing progress status of all jobs in the given SparkContext. */
private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
Expand All @@ -40,7 +40,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {

def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val jobId = Option(request.getParameter("id")).map(_.toInt)
//stripXSS is called first to remove suspicious characters used in XSS attacks
val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
jobId.foreach { id =>
if (jobProgresslistener.activeJobs.contains(id)) {
sc.foreach(_.cancelJob(id))
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {

def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
val poolName = Option(request.getParameter("poolname")).map { poolname =>
//stripXSS is called first to remove suspicious characters used in XSS attacks
val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname =>
UIUtils.decodeURLParameter(poolname)
}.getOrElse {
throw new IllegalArgumentException(s"Missing poolname parameter")
Expand Down
15 changes: 8 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,18 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {

def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
val parameterId = request.getParameter("id")
//stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterAttempt = request.getParameter("attempt")
val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt"))
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")

val parameterTaskPage = request.getParameter("task.page")
val parameterTaskSortColumn = request.getParameter("task.sort")
val parameterTaskSortDesc = request.getParameter("task.desc")
val parameterTaskPageSize = request.getParameter("task.pageSize")
val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize")
val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page"))
val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort"))
val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc"))
val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize"))
val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize"))

val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn =>
Expand Down
14 changes: 7 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ private[ui] class StageTableBase(
isFairScheduler: Boolean,
killEnabled: Boolean,
isFailedStage: Boolean) {
val allParameters = request.getParameterMap().asScala.toMap
//stripXSS is called to remove suspicious characters used in XSS attacks
val allParameters = request.getParameterMap.asScala.toMap.mapValues(UIUtils.stripXSSMap(_))
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag))
.map(para => para._1 + "=" + para._2(0))

val parameterStagePage = request.getParameter(stageTag + ".page")
val parameterStageSortColumn = request.getParameter(stageTag + ".sort")
val parameterStageSortDesc = request.getParameter(stageTag + ".desc")
val parameterStagePageSize = request.getParameter(stageTag + ".pageSize")
val parameterStagePrevPageSize = request.getParameter(stageTag + ".prevPageSize")
val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page"))
val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort"))
val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc"))
val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize"))
val parameterStagePrevPageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize"))

val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1)
val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn =>
Expand Down Expand Up @@ -512,4 +513,3 @@ private[ui] class StageDataSource(
}
}
}

5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}

/** Web UI showing progress status of all stages in the given SparkContext. */
private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
Expand All @@ -39,7 +39,8 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"

def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val stageId = Option(request.getParameter("id")).map(_.toInt)
//stripXSS is called first to remove suspicious characters used in XSS attacks
val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
stageId.foreach { id =>
if (progressListener.activeStages.contains(id)) {
sc.foreach(_.cancelStage(id, "killed via the Web UI"))
Expand Down
13 changes: 7 additions & 6 deletions core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
private val listener = parent.listener

def render(request: HttpServletRequest): Seq[Node] = {
val parameterId = request.getParameter("id")
//stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterBlockPage = request.getParameter("block.page")
val parameterBlockSortColumn = request.getParameter("block.sort")
val parameterBlockSortDesc = request.getParameter("block.desc")
val parameterBlockPageSize = request.getParameter("block.pageSize")
val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize")
val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page"))
val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort"))
val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc"))
val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize"))
val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize"))

val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1)
val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") {

override def render(request: HttpServletRequest): Seq[Node] = {
val driverId = request.getParameter("id")
//stripXSS is called first to remove suspicious characters used in XSS attacks
val driverId = UIUtils.stripXSS(request.getParameter("id"))
require(driverId != null && driverId.nonEmpty, "Missing id parameter")

val state = parent.scheduler.getDriverState(driverId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
private val listener = parent.listener

override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized {
val parameterExecutionId = request.getParameter("id")
//stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterExecutionId != null && parameterExecutionId.nonEmpty,
"Missing execution id parameter")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)

/** Render the page */
def render(request: HttpServletRequest): Seq[Node] = {
val parameterId = request.getParameter("id")
//stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val content =
Expand Down Expand Up @@ -197,4 +198,3 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
}

def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized {
val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse {
//stripXSS is called first to remove suspicious characters used in XSS attacks
val batchTime = Option(SparkUIUtils.stripXSS(request.getParameter("id"))).map(id => Time(id.toLong)).getOrElse {
throw new IllegalArgumentException(s"Missing id parameter")
}
val formattedBatchTime =
Expand Down