diff --git a/.rat-excludes b/.rat-excludes
index 0240e81c45ea2..236c2db05367c 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -91,3 +91,5 @@ help/*
html/*
INDEX
.lintr
+gen-java.*
+.*avpr
diff --git a/R/README.md b/R/README.md
index d7d65b4f0eca5..005f56da1670c 100644
--- a/R/README.md
+++ b/R/README.md
@@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R
#### Build Spark
-Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run
+Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run
```
build/mvn -DskipTests -Psparkr package
```
diff --git a/R/install-dev.bat b/R/install-dev.bat
index 008a5c668bc45..f32670b67de96 100644
--- a/R/install-dev.bat
+++ b/R/install-dev.bat
@@ -25,3 +25,8 @@ set SPARK_HOME=%~dp0..
MKDIR %SPARK_HOME%\R\lib
R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\
+
+rem Zip the SparkR package so that it can be distributed to worker nodes on YARN
+pushd %SPARK_HOME%\R\lib
+%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR
+popd
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 1edd551f8d243..4972bb9217072 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -34,7 +34,7 @@ LIB_DIR="$FWDIR/lib"
mkdir -p $LIB_DIR
-pushd $FWDIR
+pushd $FWDIR > /dev/null
# Generate Rd files if devtools is installed
Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }'
@@ -42,4 +42,8 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo
# Install SparkR to $LIB_DIR
R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
-popd
+# Zip the SparkR package so that it can be distributed to worker nodes on YARN
+cd $LIB_DIR
+jar cfM "$LIB_DIR/sparkr.zip" SparkR
+
+popd > /dev/null
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index efc85bbc4b316..4949d86d20c91 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -29,7 +29,7 @@ Collate:
'client.R'
'context.R'
'deserialize.R'
+ 'mllib.R'
'serialize.R'
'sparkR.R'
'utils.R'
- 'zzz.R'
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7f857222452d4..a329e14f25aeb 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -10,6 +10,11 @@ export("sparkR.init")
export("sparkR.stop")
export("print.jobj")
+# MLlib integration
+exportMethods("glm",
+ "predict",
+ "summary")
+
# Job group lifecycle management methods
export("setJobGroup",
"clearJobGroup",
@@ -22,6 +27,7 @@ exportMethods("arrange",
"collect",
"columns",
"count",
+ "crosstab",
"describe",
"distinct",
"dropna",
@@ -77,6 +83,7 @@ exportMethods("abs",
"atan",
"atan2",
"avg",
+ "between",
"cast",
"cbrt",
"ceiling",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 60702824acb46..f4c93d3c7dd67 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1314,7 +1314,7 @@ setMethod("except",
#' write.df(df, "myfile", "parquet", "overwrite")
#' }
setMethod("write.df",
- signature(df = "DataFrame", path = 'character'),
+ signature(df = "DataFrame", path = "character"),
function(df, path, source = NULL, mode = "append", ...){
if (is.null(source)) {
sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
@@ -1328,7 +1328,7 @@ setMethod("write.df",
jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode)
options <- varargsToEnv(...)
if (!is.null(path)) {
- options[['path']] = path
+ options[["path"]] <- path
}
callJMethod(df@sdf, "save", source, jmode, options)
})
@@ -1337,7 +1337,7 @@ setMethod("write.df",
#' @aliases saveDF
#' @export
setMethod("saveDF",
- signature(df = "DataFrame", path = 'character'),
+ signature(df = "DataFrame", path = "character"),
function(df, path, source = NULL, mode = "append", ...){
write.df(df, path, source, mode, ...)
})
@@ -1375,8 +1375,8 @@ setMethod("saveDF",
#' saveAsTable(df, "myfile")
#' }
setMethod("saveAsTable",
- signature(df = "DataFrame", tableName = 'character', source = 'character',
- mode = 'character'),
+ signature(df = "DataFrame", tableName = "character", source = "character",
+ mode = "character"),
function(df, tableName, source = NULL, mode="append", ...){
if (is.null(source)) {
sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
@@ -1554,3 +1554,31 @@ setMethod("fillna",
}
dataFrame(sdf)
})
+
+#' crosstab
+#'
+#' Computes a pair-wise frequency table of the given columns. Also known as a contingency
+#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
+#' non-zero pair frequencies will be returned.
+#'
+#' @param col1 name of the first column. Distinct items will make the first item of each row.
+#' @param col2 name of the second column. Distinct items will make the column names of the output.
+#' @return a local R data.frame representing the contingency table. The first column of each row
+#' will be the distinct values of `col1` and the column names will be the distinct values
+#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
+#' occurrences will have zero as their counts.
+#'
+#' @rdname statfunctions
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- jsonFile(sqlCtx, "/path/to/file.json")
+#' ct <- crosstab(df, "title", "gender")
+#' }
+setMethod("crosstab",
+ signature(x = "DataFrame", col1 = "character", col2 = "character"),
+ function(x, col1, col2) {
+ statFunctions <- callJMethod(x@sdf, "stat")
+ sct <- callJMethod(statFunctions, "crosstab", col1, col2)
+ collect(dataFrame(sct))
+ })
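
A minimal usage sketch of the new `crosstab` method (illustrative only, not part of the patch; assumes an active SparkR session with `sqlContext` initialized and a hypothetical JSON input that has `title` and `gender` columns):

```
# Illustrative sketch: pair-wise frequency table of two columns
df <- jsonFile(sqlContext, "/path/to/people.json")  # hypothetical input path
ct <- crosstab(df, "title", "gender")
ct  # a local R data.frame; the first column is named "title_gender"
```
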
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 89511141d3ef7..d2d096709245d 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -165,7 +165,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
serializedFuncArr,
rdd@env$prev_serializedMode,
packageNamesArr,
- as.character(.sparkREnv[["libname"]]),
broadcastArr,
callJMethod(prev_jrdd, "classTag"))
} else {
@@ -175,7 +174,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
rdd@env$prev_serializedMode,
serializedMode,
packageNamesArr,
- as.character(.sparkREnv[["libname"]]),
broadcastArr,
callJMethod(prev_jrdd, "classTag"))
}
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 9a743a3411533..110117a18ccbc 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -86,7 +86,9 @@ infer_type <- function(x) {
createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) {
if (is.data.frame(data)) {
# get the names of columns, they will be put into RDD
- schema <- names(data)
+ if (is.null(schema)) {
+ schema <- names(data)
+ }
n <- nrow(data)
m <- ncol(data)
# get rid of factor type
@@ -455,7 +457,7 @@ dropTempTable <- function(sqlContext, tableName) {
read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
options <- varargsToEnv(...)
if (!is.null(path)) {
- options[['path']] <- path
+ options[["path"]] <- path
}
if (is.null(source)) {
sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
@@ -504,7 +506,7 @@ loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) {
options <- varargsToEnv(...)
if (!is.null(path)) {
- options[['path']] <- path
+ options[["path"]] <- path
}
sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options)
dataFrame(sdf)
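
With the change above, `createDataFrame` only falls back to `names(data)` when no schema is supplied, so an explicit schema passed alongside a local data.frame is now honored. A rough sketch (not part of the patch, mirroring the new test in test_sparkSQL.R):

```
schema <- structType(structField("name", "string"),
                     structField("age", "integer"),
                     structField("height", "float"))
localDF <- data.frame(name = c("John", "Smith"), age = c(19L, 23L), height = c(164.1, 181.4))
df <- createDataFrame(sqlContext, localDF, schema)
dtypes(df)  # list(c("name", "string"), c("age", "int"), c("height", "float"))
```
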
diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R
index 2fb6fae55f28c..49162838b8d1a 100644
--- a/R/pkg/R/backend.R
+++ b/R/pkg/R/backend.R
@@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) {
# TODO: check the status code to output error information
returnStatus <- readInt(conn)
- stopifnot(returnStatus == 0)
+ if (returnStatus != 0) {
+ stop(readString(conn))
+ }
readObject(conn)
}
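
With `stopifnot(returnStatus == 0)` replaced by `stop(readString(conn))`, a failing JVM call now surfaces the backend's error message as an ordinary R condition. A hedged sketch of catching it (not part of the patch, modeled on the new test in test_sparkSQL.R):

```
msg <- tryCatch(sql(sqlContext, "select * from blah"),
                error = function(e) conditionMessage(e))
grepl("Table Not Found: blah", msg)  # TRUE when the table does not exist
```
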
diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
index 78c7a3037ffac..c811d1dac3bd5 100644
--- a/R/pkg/R/client.R
+++ b/R/pkg/R/client.R
@@ -36,9 +36,9 @@ connectBackend <- function(hostname, port, timeout = 6000) {
determineSparkSubmitBin <- function() {
if (.Platform$OS.type == "unix") {
- sparkSubmitBinName = "spark-submit"
+ sparkSubmitBinName <- "spark-submit"
} else {
- sparkSubmitBinName = "spark-submit.cmd"
+ sparkSubmitBinName <- "spark-submit.cmd"
}
sparkSubmitBinName
}
@@ -48,7 +48,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack
jars <- paste("--jars", jars)
}
- if (packages != "") {
+ if (!identical(packages, "")) {
packages <- paste("--packages", packages)
}
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 8e4b0f5bf1c4d..2892e1416cc65 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -187,6 +187,23 @@ setMethod("substr", signature(x = "Column"),
column(jc)
})
+#' between
+#'
+#' Test if the column is between the lower bound and upper bound, inclusive.
+#'
+#' @rdname column
+#'
+#' @param bounds lower and upper bounds
+setMethod("between", signature(x = "Column"),
+ function(x, bounds) {
+ if (is.vector(bounds) && length(bounds) == 2) {
+ jc <- callJMethod(x@jc, "between", bounds[1], bounds[2])
+ column(jc)
+ } else {
+ stop("bounds should be a vector of lower and upper bounds")
+ }
+ })
+
#' Casts the column to a different data type.
#'
#' @rdname column
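
A short sketch of the new `between` column method (illustrative only, not part of the patch; assumes a DataFrame `df` with `name` and `age` columns):

```
df <- jsonFile(sqlContext, "/path/to/people.json")  # hypothetical input path
# between() is inclusive on both ends and takes a length-2 vector of bounds
adults <- select(df, df$name, between(df$age, c(18, 30)))
```
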
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index d961bbc383688..6d364f77be7ee 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -23,6 +23,7 @@
# Int -> integer
# String -> character
# Boolean -> logical
+# Float -> double
# Double -> double
# Long -> double
# Array[Byte] -> raw
@@ -101,11 +102,11 @@ readList <- function(con) {
readRaw <- function(con) {
dataLen <- readInt(con)
- data <- readBin(con, raw(), as.integer(dataLen), endian = "big")
+ readBin(con, raw(), as.integer(dataLen), endian = "big")
}
readRawLen <- function(con, dataLen) {
- data <- readBin(con, raw(), as.integer(dataLen), endian = "big")
+ readBin(con, raw(), as.integer(dataLen), endian = "big")
}
readDeserialize <- function(con) {
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 79055b7f18558..a3a121058e165 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -20,7 +20,8 @@
# @rdname aggregateRDD
# @seealso reduce
# @export
-setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") })
+setGeneric("aggregateRDD",
+ function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") })
# @rdname cache-methods
# @export
@@ -58,6 +59,10 @@ setGeneric("count", function(x) { standardGeneric("count") })
# @export
setGeneric("countByValue", function(x) { standardGeneric("countByValue") })
+# @rdname statfunctions
+# @export
+setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") })
+
# @rdname distinct
# @export
setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") })
@@ -249,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues")
# @rdname intersection
# @export
-setGeneric("intersection", function(x, other, numPartitions = 1) {
- standardGeneric("intersection") })
+setGeneric("intersection",
+ function(x, other, numPartitions = 1) {
+ standardGeneric("intersection")
+ })
# @rdname keys
# @export
@@ -484,9 +491,7 @@ setGeneric("sample",
#' @rdname sample
#' @export
setGeneric("sample_frac",
- function(x, withReplacement, fraction, seed) {
- standardGeneric("sample_frac")
- })
+ function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })
#' @rdname saveAsParquetFile
#' @export
@@ -548,8 +553,8 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn
#' @rdname withColumnRenamed
#' @export
-setGeneric("withColumnRenamed", function(x, existingCol, newCol) {
- standardGeneric("withColumnRenamed") })
+setGeneric("withColumnRenamed",
+ function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") })
###################### Column Methods ##########################
@@ -566,6 +571,10 @@ setGeneric("asc", function(x) { standardGeneric("asc") })
#' @export
setGeneric("avg", function(x, ...) { standardGeneric("avg") })
+#' @rdname column
+#' @export
+setGeneric("between", function(x, bounds) { standardGeneric("between") })
+
#' @rdname column
#' @export
setGeneric("cast", function(x, dataType) { standardGeneric("cast") })
@@ -656,3 +665,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
#' @rdname column
#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })
+
+#' @rdname glm
+#' @export
+setGeneric("glm")
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 8f1c68f7c4d28..576ac72f40fc0 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -87,7 +87,7 @@ setMethod("count",
setMethod("agg",
signature(x = "GroupedData"),
function(x, ...) {
- cols = list(...)
+ cols <- list(...)
stopifnot(length(cols) > 0)
if (is.character(cols[[1]])) {
cols <- varargsToEnv(...)
@@ -97,7 +97,7 @@ setMethod("agg",
if (!is.null(ns)) {
for (n in ns) {
if (n != "") {
- cols[[n]] = alias(cols[[n]], n)
+ cols[[n]] <- alias(cols[[n]], n)
}
}
}
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
new file mode 100644
index 0000000000000..efddcc1d8d71c
--- /dev/null
+++ b/R/pkg/R/mllib.R
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# mllib.R: Provides methods for MLlib integration
+
+#' @title S4 class that represents a PipelineModel
+#' @param model A Java object reference to the backing Scala PipelineModel
+#' @export
+setClass("PipelineModel", representation(model = "jobj"))
+
+#' Fits a generalized linear model
+#'
+#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
+#'
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '+', '-', and '.'.
+#' @param data DataFrame for training
+#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
+#' @param lambda Regularization parameter
+#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
+#' @return a fitted MLlib model
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' data(iris)
+#' df <- createDataFrame(sqlContext, iris)
+#' model <- glm(Sepal_Length ~ Sepal_Width, df)
+#'}
+setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
+ function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
+ family <- match.arg(family)
+ model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "fitRModelFormula", deparse(formula), data@sdf, family, lambda,
+ alpha)
+ return(new("PipelineModel", model = model))
+ })
+
+#' Make predictions from a model
+#'
+#' Makes predictions from a model produced by glm(), similarly to R's predict().
+#'
+#' @param model A fitted MLlib model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted values
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#'}
+setMethod("predict", signature(object = "PipelineModel"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
+ })
+
+#' Get the summary of a model
+#'
+#' Returns the summary of a model produced by glm(), similarly to R's summary().
+#'
+#' @param model A fitted MLlib model
+#' @return a list with a 'coefficients' component, which is the matrix of coefficients. See
+#' summary.glm for more information.
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' summary(model)
+#'}
+setMethod("summary", signature(object = "PipelineModel"),
+ function(object) {
+ features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelFeatures", object@model)
+ weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelWeights", object@model)
+ coefficients <- as.matrix(unlist(weights))
+ colnames(coefficients) <- c("Estimate")
+ rownames(coefficients) <- unlist(features)
+ return(list(coefficients = coefficients))
+ })
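
An end-to-end sketch of the new MLlib bindings (illustrative only, not part of the patch; it mirrors the roxygen examples above and the tests in test_mllib.R):

```
training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, family = "gaussian")
summary(model)                   # list with a coefficients matrix
predictions <- predict(model, training)
showDF(select(predictions, "prediction"))
```
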
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index 7f902ba8e683e..83801d3209700 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -202,8 +202,8 @@ setMethod("partitionBy",
packageNamesArr <- serialize(.sparkREnv$.packages,
connection = NULL)
- broadcastArr <- lapply(ls(.broadcastNames), function(name) {
- get(name, .broadcastNames) })
+ broadcastArr <- lapply(ls(.broadcastNames),
+ function(name) { get(name, .broadcastNames) })
jrdd <- getJRDD(x)
# We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])],
@@ -215,7 +215,6 @@ setMethod("partitionBy",
serializedHashFuncBytes,
getSerializedMode(x),
packageNamesArr,
- as.character(.sparkREnv$libname),
broadcastArr,
callJMethod(jrdd, "classTag"))
@@ -560,8 +559,8 @@ setMethod("join",
# Left outer join two RDDs
#
# @description
-# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V).
-# The key types of the two RDDs should be the same.
+# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of
+# the form list(K, V). The key types of the two RDDs should be the same.
#
# @param x An RDD to be joined. Should be an RDD where each element is
# list(K, V).
@@ -597,8 +596,8 @@ setMethod("leftOuterJoin",
# Right outer join two RDDs
#
# @description
-# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V).
-# The key types of the two RDDs should be the same.
+# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of
+# the form list(K, V). The key types of the two RDDs should be the same.
#
# @param x An RDD to be joined. Should be an RDD where each element is
# list(K, V).
@@ -634,8 +633,8 @@ setMethod("rightOuterJoin",
# Full outer join two RDDs
#
# @description
-# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V).
-# The key types of the two RDDs should be the same.
+# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of
+# the form list(K, V). The key types of the two RDDs should be the same.
#
# @param x An RDD to be joined. Should be an RDD where each element is
# list(K, V).
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 15e2bdbd55d79..79c744ef29c23 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -69,11 +69,14 @@ structType.structField <- function(x, ...) {
#' @param ... further arguments passed to or from other methods
print.structType <- function(x, ...) {
cat("StructType\n",
- sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(),
- "\", type = \"", field$dataType.toString(),
- "\", nullable = ", field$nullable(), "\n",
- sep = "") })
- , sep = "")
+ sapply(x$fields(),
+ function(field) {
+ paste("|-", "name = \"", field$name(),
+ "\", type = \"", field$dataType.toString(),
+ "\", nullable = ", field$nullable(), "\n",
+ sep = "")
+ }),
+ sep = "")
}
#' structField
@@ -123,6 +126,7 @@ structField.character <- function(x, type, nullable = TRUE) {
}
options <- c("byte",
"integer",
+ "float",
"double",
"numeric",
"character",
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 78535eff0d2f6..311021e5d8473 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -140,8 +140,8 @@ writeType <- function(con, class) {
jobj = "j",
environment = "e",
Date = "D",
- POSIXlt = 't',
- POSIXct = 't',
+ POSIXlt = "t",
+ POSIXct = "t",
stop(paste("Unsupported type for serialization", class)))
writeBin(charToRaw(type), con)
}
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 86233e01db365..e83104f116422 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -17,16 +17,13 @@
.sparkREnv <- new.env()
-sparkR.onLoad <- function(libname, pkgname) {
- .sparkREnv$libname <- libname
-}
-
# Utility function that returns TRUE if we have an active connection to the
# backend and FALSE otherwise
connExists <- function(env) {
tryCatch({
exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]])
- }, error = function(err) {
+ },
+ error = function(err) {
return(FALSE)
})
}
@@ -80,7 +77,6 @@ sparkR.stop <- function() {
#' @param sparkEnvir Named list of environment variables to set on worker nodes.
#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors.
#' @param sparkJars Character string vector of jar files to pass to the worker nodes.
-#' @param sparkRLibDir The path where R is installed on the worker nodes.
#' @param sparkPackages Character string vector of packages from spark-packages.org
#' @export
#' @examples
@@ -101,24 +97,21 @@ sparkR.init <- function(
sparkEnvir = list(),
sparkExecutorEnv = list(),
sparkJars = "",
- sparkRLibDir = "",
sparkPackages = "") {
if (exists(".sparkRjsc", envir = .sparkREnv)) {
- cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")
+ cat(paste("Re-using existing Spark Context.",
+ "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n"))
return(get(".sparkRjsc", envir = .sparkREnv))
}
- sparkMem <- Sys.getenv("SPARK_MEM", "1024m")
jars <- suppressWarnings(normalizePath(as.character(sparkJars)))
# Classpath separator is ";" on Windows
# URI needs four /// as from http://stackoverflow.com/a/18522792
if (.Platform$OS.type == "unix") {
- collapseChar <- ":"
uriSep <- "//"
} else {
- collapseChar <- ";"
uriSep <- "////"
}
@@ -145,7 +138,7 @@ sparkR.init <- function(
if (!file.exists(path)) {
stop("JVM is not ready after 10 seconds")
}
- f <- file(path, open='rb')
+ f <- file(path, open="rb")
backendPort <- readInt(f)
monitorPort <- readInt(f)
close(f)
@@ -161,7 +154,8 @@ sparkR.init <- function(
.sparkREnv$backendPort <- backendPort
tryCatch({
connectBackend("localhost", backendPort)
- }, error = function(err) {
+ },
+ error = function(err) {
stop("Failed to connect JVM\n")
})
@@ -169,10 +163,6 @@ sparkR.init <- function(
sparkHome <- normalizePath(sparkHome)
}
- if (nchar(sparkRLibDir) != 0) {
- .sparkREnv$libname <- sparkRLibDir
- }
-
sparkEnvirMap <- new.env()
for (varname in names(sparkEnvir)) {
sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
@@ -180,14 +170,16 @@ sparkR.init <- function(
sparkExecutorEnvMap <- new.env()
if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) {
- sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
+ sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
+ paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
}
for (varname in names(sparkExecutorEnv)) {
sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]]
}
nonEmptyJars <- Filter(function(x) { x != "" }, jars)
- localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) })
+ localJarPaths <- sapply(nonEmptyJars,
+ function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) })
# Set the start time to identify jobjs
# Seconds resolution is good enough for this purpose, so use ints
@@ -274,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) {
ssc <- callJMethod(sc, "sc")
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.HiveContext", ssc)
- }, error = function(err) {
+ },
+ error = function(err) {
stop("Spark SQL is not built with Hive support")
})
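
Since `sparkRLibDir` has been removed from `sparkR.init()`, callers no longer point the context at an R library path; package distribution is handled by the build scripts (the sparkr.zip created in install-dev.sh/install-dev.bat) and the SPARKR_PACKAGE_DIR change in inst/profile/general.R. A minimal initialization sketch (not part of the patch):

```
sc <- sparkR.init("local[2]", "SparkR-example",
                  sparkEnvir = list(spark.executor.memory = "1g"))
sqlContext <- sparkRSQL.init(sc)
```
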
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 13cec0f712fb4..3f45589a50443 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -41,8 +41,8 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL,
if (isInstanceOf(obj, "scala.Tuple2")) {
# JavaPairRDD[Array[Byte], Array[Byte]].
- keyBytes = callJMethod(obj, "_1")
- valBytes = callJMethod(obj, "_2")
+ keyBytes <- callJMethod(obj, "_1")
+ valBytes <- callJMethod(obj, "_2")
res <- list(unserialize(keyBytes),
unserialize(valBytes))
} else {
@@ -334,18 +334,21 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY",
"MEMORY_ONLY_SER_2",
"OFF_HEAP")) {
match.arg(newLevel)
+ storageLevelClass <- "org.apache.spark.storage.StorageLevel"
storageLevel <- switch(newLevel,
- "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"),
- "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"),
- "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"),
- "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"),
- "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"),
- "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"),
- "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"),
- "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"),
- "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"),
- "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"),
- "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP"))
+ "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"),
+ "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"),
+ "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"),
+ "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"),
+ "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass,
+ "MEMORY_AND_DISK_SER"),
+ "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass,
+ "MEMORY_AND_DISK_SER_2"),
+ "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"),
+ "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"),
+ "MEMORY_ONLY_SER" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER"),
+ "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"),
+ "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP"))
}
# Utility function for functions where an argument needs to be integer but we want to allow
@@ -387,14 +390,17 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
for (i in 1:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
- } else { # if node[[1]] is length of 1, check for some R special functions.
+ } else {
+ # if node[[1]] is length of 1, check for some R special functions.
nodeChar <- as.character(node[[1]])
- if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol.
+ if (nodeChar == "{" || nodeChar == "(") {
+ # Skip start symbol.
for (i in 2:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "<-" || nodeChar == "=" ||
- nodeChar == "<<-") { # Assignment Ops.
+ nodeChar == "<<-") {
+ # Assignment Ops.
defVar <- node[[2]]
if (length(defVar) == 1 && typeof(defVar) == "symbol") {
# Add the defined variable name into defVars.
@@ -405,14 +411,16 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
- } else if (nodeChar == "function") { # Function definition.
+ } else if (nodeChar == "function") {
+ # Function definition.
# Add parameter names.
newArgs <- names(node[[2]])
lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
- } else if (nodeChar == "$") { # Skip the field.
+ } else if (nodeChar == "$") {
+ # Skip the field.
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
} else if (nodeChar == "::" || nodeChar == ":::") {
processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
@@ -426,7 +434,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
(typeof(node) == "symbol" || typeof(node) == "language")) {
# Base case: current AST node is a leaf node and a symbol or a function call.
nodeChar <- as.character(node)
- if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
+ if (!nodeChar %in% defVars$data) {
+ # Not a function parameter or local variable.
func.env <- oldEnv
topEnv <- parent.env(.GlobalEnv)
# Search in function environment, and function's enclosing environments
@@ -436,20 +445,24 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
while (!identical(func.env, topEnv)) {
# Namespaces other than "SparkR" will not be searched.
if (!isNamespace(func.env) ||
- (getNamespaceName(func.env) == "SparkR" &&
- !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
+ (getNamespaceName(func.env) == "SparkR" &&
+ !(nodeChar %in% getNamespaceExports("SparkR")))) {
+ # Only include SparkR internals.
+
# Set parameter 'inherits' to FALSE since we do not need to search in
# attached package environments.
if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
error = function(e) { FALSE })) {
obj <- get(nodeChar, envir = func.env, inherits = FALSE)
- if (is.function(obj)) { # If the node is a function call.
+ if (is.function(obj)) {
+ # If the node is a function call.
funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
ifnotfound = list(list(NULL)))[[1]]
found <- sapply(funcList, function(func) {
ifelse(identical(func, obj), TRUE, FALSE)
})
- if (sum(found) > 0) { # If function has been examined, ignore.
+ if (sum(found) > 0) {
+ # If function has been examined, ignore.
break
}
# Function has not been examined, record it and recursively clean its closure.
@@ -492,7 +505,8 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
# environment. First, function's arguments are added to defVars.
defVars <- initAccumulator()
argNames <- names(as.list(args(func)))
- for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist.
+ for (i in 1:(length(argNames) - 1)) {
+ # Remove the ending NULL in pairlist.
addItemToAccumulator(defVars, argNames[i])
}
# Recursively examine variables in the function body.
@@ -545,9 +559,11 @@ mergePartitions <- function(rdd, zip) {
lengthOfKeys <- part[[len - lengthOfValues]]
stopifnot(len == lengthOfKeys + lengthOfValues)
- # For zip operation, check if corresponding partitions of both RDDs have the same number of elements.
+ # For zip operation, check if corresponding partitions
+ # of both RDDs have the same number of elements.
if (zip && lengthOfKeys != lengthOfValues) {
- stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.")
+ stop(paste("Can only zip RDDs with same number of elements",
+ "in each pair of corresponding partitions."))
}
if (lengthOfKeys > 1) {
diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R
index 8fe711b622086..2a8a8213d0849 100644
--- a/R/pkg/inst/profile/general.R
+++ b/R/pkg/inst/profile/general.R
@@ -16,7 +16,7 @@
#
.First <- function() {
- home <- Sys.getenv("SPARK_HOME")
- .libPaths(c(file.path(home, "R", "lib"), .libPaths()))
+ packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR")
+ .libPaths(c(packageDir, .libPaths()))
Sys.setenv(NOAWT=1)
}
diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R
index ccaea18ecab2a..f2452ed97d2ea 100644
--- a/R/pkg/inst/tests/test_binaryFile.R
+++ b/R/pkg/inst/tests/test_binaryFile.R
@@ -20,7 +20,7 @@ context("functions on binary files")
# JavaSparkContext handle
sc <- sparkR.init()
-mockFile = c("Spark is pretty.", "Spark is awesome.")
+mockFile <- c("Spark is pretty.", "Spark is awesome.")
test_that("saveAsObjectFile()/objectFile() following textFile() works", {
fileName1 <- tempfile(pattern="spark-test", fileext=".tmp")
diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R
index 3be8c65a6c1a0..dca0657c57e0d 100644
--- a/R/pkg/inst/tests/test_binary_function.R
+++ b/R/pkg/inst/tests/test_binary_function.R
@@ -76,7 +76,7 @@ test_that("zipPartitions() on RDDs", {
expect_equal(actual,
list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))))
- mockFile = c("Spark is pretty.", "Spark is awesome.")
+ mockFile <- c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R
index 30b05c1a2afcd..8a20991f89af8 100644
--- a/R/pkg/inst/tests/test_client.R
+++ b/R/pkg/inst/tests/test_client.R
@@ -30,3 +30,7 @@ test_that("no package specified doesn't add packages flag", {
expect_equal(gsub("[[:space:]]", "", args),
"")
})
+
+test_that("multiple packages don't produce a warning", {
+ expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning()))
+})
diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R
index 844d86f3cc97f..cc1faeabffe30 100644
--- a/R/pkg/inst/tests/test_includeJAR.R
+++ b/R/pkg/inst/tests/test_includeJAR.R
@@ -18,8 +18,8 @@ context("include an external JAR in SparkContext")
runScript <- function() {
sparkHome <- Sys.getenv("SPARK_HOME")
- jarPath <- paste("--jars",
- shQuote(file.path(sparkHome, "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar")))
+ sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar"
+ jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath)))
scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R")
submitPath <- file.path(sparkHome, "bin/spark-submit")
res <- system2(command = submitPath,
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
new file mode 100644
index 0000000000000..f272de78ad4a6
--- /dev/null
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -0,0 +1,61 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(testthat)
+
+context("MLlib functions")
+
+# Tests for MLlib functions in SparkR
+
+sc <- sparkR.init()
+
+sqlContext <- sparkRSQL.init(sc)
+
+test_that("glm and predict", {
+ training <- createDataFrame(sqlContext, iris)
+ test <- select(training, "Sepal_Length")
+ model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
+ prediction <- predict(model, test)
+ expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+})
+
+test_that("predictions match with native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+})
+
+test_that("dot minus and intercept vs native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ model <- glm(Sepal_Width ~ . - Species + 0, data = training)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+})
+
+test_that("summary coefficients match with native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+ coefs <- as.vector(stats$coefficients)
+ rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
+ expect_true(all(abs(rCoefs - coefs) < 1e-6))
+ expect_true(all(
+ as.character(stats$features) ==
+ c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+})
diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R
index fc3c01d837de4..6c3aaab8c711e 100644
--- a/R/pkg/inst/tests/test_rdd.R
+++ b/R/pkg/inst/tests/test_rdd.R
@@ -447,7 +447,7 @@ test_that("zipRDD() on RDDs", {
expect_equal(actual,
list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)))
- mockFile = c("Spark is pretty.", "Spark is awesome.")
+ mockFile <- c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
@@ -483,7 +483,7 @@ test_that("cartesian() on RDDs", {
actual <- collect(cartesian(rdd, emptyRdd))
expect_equal(actual, list())
- mockFile = c("Spark is pretty.", "Spark is awesome.")
+ mockFile <- c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
@@ -669,13 +669,15 @@ test_that("fullOuterJoin() on pairwise RDDs", {
rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3)))
rdd2 <- parallelize(sc, list(list(1,1), list(2,4)))
actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
- expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL)))
+ expected <- list(list(1, list(2, 1)), list(1, list(3, 1)),
+ list(2, list(NULL, 4)), list(3, list(3, NULL)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1)))
rdd2 <- parallelize(sc, list(list("a",1), list("b",4)))
actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
- expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL)))
+ expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)),
+ list("a", list(3, 1)), list("c", list(1, NULL)))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(expected))
@@ -683,13 +685,15 @@ test_that("fullOuterJoin() on pairwise RDDs", {
rdd2 <- parallelize(sc, list(list(3,3), list(4,4)))
actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
- sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4)))))
+ sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)),
+ list(3, list(NULL, 3)), list(4, list(NULL, 4)))))
rdd1 <- parallelize(sc, list(list("a",1), list("b",2)))
rdd2 <- parallelize(sc, list(list("c",3), list("d",4)))
actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
- sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3)))))
+ sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)),
+ list("d", list(NULL, 4)), list("c", list(NULL, 3)))))
})
test_that("sortByKey() on pairwise RDDs", {
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 0e4235ea8b4b3..61c8a7ec7d837 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -57,9 +57,9 @@ test_that("infer types", {
expect_equal(infer_type(as.Date("2015-03-11")), "date")
expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
expect_equal(infer_type(c(1L, 2L)),
- list(type = 'array', elementType = "integer", containsNull = TRUE))
+ list(type = "array", elementType = "integer", containsNull = TRUE))
expect_equal(infer_type(list(1L, 2L)),
- list(type = 'array', elementType = "integer", containsNull = TRUE))
+ list(type = "array", elementType = "integer", containsNull = TRUE))
testStruct <- infer_type(list(a = 1L, b = "2"))
expect_equal(class(testStruct), "structType")
checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
@@ -108,6 +108,33 @@ test_that("create DataFrame from RDD", {
expect_equal(count(df), 10)
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+
+ df <- jsonFile(sqlContext, jsonPathNa)
+ hiveCtx <- tryCatch({
+ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+ },
+ error = function(err) {
+ skip("Hive is not build with SparkSQL, skipped")
+ })
+ sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)")
+ insertInto(df, "people")
+ expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16))
+ expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5))
+
+ schema <- structType(structField("name", "string"), structField("age", "integer"),
+ structField("height", "float"))
+ df2 <- createDataFrame(sqlContext, df.toRDD, schema)
+ expect_equal(columns(df2), c("name", "age", "height"))
+ expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float")))
+ expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5))
+
+ localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7))
+ df <- createDataFrame(sqlContext, localDF, schema)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 3)
+ expect_equal(columns(df), c("name", "age", "height"))
+ expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float")))
+ expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10))
})
test_that("convert NAs to null type in DataFrames", {
@@ -391,7 +418,7 @@ test_that("collect() and take() on a DataFrame return the same number of rows an
expect_equal(ncol(collect(df)), ncol(take(df, 10)))
})
-test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", {
+test_that("multiple pipeline transformations result in an RDD with the correct values", {
df <- jsonFile(sqlContext, jsonPath)
first <- lapply(df, function(row) {
row$age <- row$age + 5
@@ -576,7 +603,8 @@ test_that("write.df() as parquet file", {
test_that("test HiveContext", {
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
- }, error = function(err) {
+ },
+ error = function(err) {
skip("Hive is not build with SparkSQL, skipped")
})
df <- createExternalTable(hiveCtx, "json", jsonPath, "json")
@@ -612,6 +640,18 @@ test_that("column functions", {
c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c)
c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c)
c9 <- toDegrees(c) + toRadians(c)
+
+ df <- jsonFile(sqlContext, jsonPath)
+ df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20)))
+ expect_equal(collect(df2)[[2, 1]], TRUE)
+ expect_equal(collect(df2)[[2, 2]], FALSE)
+ expect_equal(collect(df2)[[3, 1]], FALSE)
+ expect_equal(collect(df2)[[3, 2]], TRUE)
+
+ df3 <- select(df, between(df$name, c("Apache", "Spark")))
+ expect_equal(collect(df3)[[1, 1]], TRUE)
+ expect_equal(collect(df3)[[2, 1]], FALSE)
+ expect_equal(collect(df3)[[3, 1]], TRUE)
})
test_that("column binary mathfunctions", {
@@ -756,7 +796,14 @@ test_that("toJSON() returns an RDD of the correct values", {
test_that("showDF()", {
df <- jsonFile(sqlContext, jsonPath)
s <- capture.output(showDF(df))
- expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n")
+ expected <- paste("+----+-------+\n",
+ "| age| name|\n",
+ "+----+-------+\n",
+ "|null|Michael|\n",
+ "| 30| Andy|\n",
+ "| 19| Justin|\n",
+ "+----+-------+\n", sep="")
+ expect_output(s , expected)
})
test_that("isLocal()", {
@@ -942,6 +989,24 @@ test_that("fillna() on a DataFrame", {
expect_identical(expected, actual)
})
+test_that("crosstab() on a DataFrame", {
+ rdd <- lapply(parallelize(sc, 0:3), function(x) {
+ list(paste0("a", x %% 3), paste0("b", x %% 2))
+ })
+ df <- toDF(rdd, list("a", "b"))
+ ct <- crosstab(df, "a", "b")
+ ordered <- ct[order(ct$a_b),]
+ row.names(ordered) <- NULL
+ expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0),
+ stringsAsFactors = FALSE, row.names = NULL)
+ expect_identical(expected, ordered)
+})
+
+test_that("SQL error message is returned from JVM", {
+ retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
+ expect_equal(grepl("Table Not Found: blah", retError), TRUE)
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R
index 58318dfef71ab..a9cf83dbdbdb1 100644
--- a/R/pkg/inst/tests/test_textFile.R
+++ b/R/pkg/inst/tests/test_textFile.R
@@ -20,7 +20,7 @@ context("the textFile() function")
# JavaSparkContext handle
sc <- sparkR.init()
-mockFile = c("Spark is pretty.", "Spark is awesome.")
+mockFile <- c("Spark is pretty.", "Spark is awesome.")
test_that("textFile() on a local file returns an RDD", {
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R
index aa0d2a66b9082..12df4cf4f65b7 100644
--- a/R/pkg/inst/tests/test_utils.R
+++ b/R/pkg/inst/tests/test_utils.R
@@ -119,7 +119,7 @@ test_that("cleanClosure on R functions", {
# Test for overriding variables in base namespace (Issue: SparkR-196).
nums <- as.list(1:10)
rdd <- parallelize(sc, nums, 2L)
- t = 4 # Override base::t in .GlobalEnv.
+ t <- 4 # Override base::t in .GlobalEnv.
f <- function(x) { x > t }
newF <- cleanClosure(f)
env <- environment(newF)
diff --git a/R/run-tests.sh b/R/run-tests.sh
index e82ad0ba2cd06..18a1e13bdc655 100755
--- a/R/run-tests.sh
+++ b/R/run-tests.sh
@@ -23,7 +23,7 @@ FAILED=0
LOGFILE=$FWDIR/unit-tests.out
rm -f $LOGFILE
-SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
+SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
FAILED=$((PIPESTATUS[0]||$FAILED))
if [[ $FAILED != 0 ]]; then
diff --git a/bin/pyspark b/bin/pyspark
index f9dbddfa53560..8f2a3b5a7717b 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -82,4 +82,4 @@ fi
export PYSPARK_DRIVER_PYTHON
export PYSPARK_DRIVER_PYTHON_OPTS
-exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main "$@"
+exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@"
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index 45e9e3def5121..3c6169983e76b 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -35,4 +35,4 @@ set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH%
set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
-call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main %*
+call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main --name "PySparkShell" %*
diff --git a/bin/spark-shell b/bin/spark-shell
index a6dc863d83fc6..00ab7afd118b5 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -47,11 +47,11 @@ function main() {
# (see https://github.com/sbt/sbt/issues/562).
stty -icanon min 1 -echo > /dev/null 2>&1
export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix"
- "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
+ "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
stty icanon echo > /dev/null 2>&1
else
export SPARK_SUBMIT_OPTS
- "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
+ "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
fi
}
diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd
index 251309d67f860..b9b0f510d7f5d 100644
--- a/bin/spark-shell2.cmd
+++ b/bin/spark-shell2.cmd
@@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" (
set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true"
:run_shell
-%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %*
+%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %*
diff --git a/build/mvn b/build/mvn
index e8364181e8230..f62f61ee1c416 100755
--- a/build/mvn
+++ b/build/mvn
@@ -112,10 +112,17 @@ install_scala() {
# the environment
ZINC_PORT=${ZINC_PORT:-"3030"}
+# Check for the `--force` flag dictating that `mvn` should be downloaded
+# regardless of whether the system already has a `mvn` install
+if [ "$1" == "--force" ]; then
+ FORCE_MVN=1
+ shift
+fi
+
# Install Maven if necessary
MVN_BIN="$(command -v mvn)"
-if [ ! "$MVN_BIN" ]; then
+if [ ! "$MVN_BIN" -o -n "$FORCE_MVN" ]; then
install_mvn
fi
@@ -139,5 +146,7 @@ fi
# Set any `mvn` options if not already present
export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}
+echo "Using \`mvn\` from path: $MVN_BIN"
+
# Last, call the `mvn` command as usual
${MVN_BIN} "$@"
diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash
index 504be48b358fa..7930a38b9674a 100755
--- a/build/sbt-launch-lib.bash
+++ b/build/sbt-launch-lib.bash
@@ -51,9 +51,13 @@ acquire_sbt_jar () {
printf "Attempting to fetch sbt\n"
JAR_DL="${JAR}.part"
if [ $(command -v curl) ]; then
- (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
+ (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\
+ (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\
+ mv "${JAR_DL}" "${JAR}"
elif [ $(command -v wget) ]; then
- (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
+ (wget --quiet ${URL1} -O "${JAR_DL}" ||\
+ (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\
+ mv "${JAR_DL}" "${JAR}"
else
printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
exit -1
diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template
index 3a2a88219818f..27006e45e932b 100644
--- a/conf/log4j.properties.template
+++ b/conf/log4j.properties.template
@@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN
log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
+
+# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
+log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
+log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
diff --git a/core/pom.xml b/core/pom.xml
index aee0d92620606..202678779150b 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -34,6 +34,11 @@
  <name>Spark Project Core</name>
  <url>http://spark.apache.org/</url>
  <dependencies>
+    <dependency>
+      <groupId>org.apache.avro</groupId>
+      <artifactId>avro-mapred</artifactId>
+      <classifier>${avro.mapred.classifier}</classifier>
+    </dependency>
    <dependency>
      <groupId>com.google.guava</groupId>
      <artifactId>guava</artifactId>
@@ -261,7 +266,7 @@
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.module</groupId>
-      <artifactId>jackson-module-scala_2.10</artifactId>
+      <artifactId>jackson-module-scala_${scala.binary.version}</artifactId>
    </dependency>
    <dependency>
      <groupId>org.apache.derby</groupId>
@@ -281,7 +286,7 @@
    <dependency>
      <groupId>org.tachyonproject</groupId>
      <artifactId>tachyon-client</artifactId>
-      <version>0.6.4</version>
+      <version>0.7.0</version>
      <exclusions>
        <exclusion>
          <groupId>org.apache.hadoop</groupId>
@@ -292,36 +297,12 @@
          <artifactId>curator-recipes</artifactId>
        </exclusion>
        <exclusion>
-          <groupId>org.eclipse.jetty</groupId>
-          <artifactId>jetty-jsp</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.eclipse.jetty</groupId>
-          <artifactId>jetty-webapp</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.eclipse.jetty</groupId>
-          <artifactId>jetty-server</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.eclipse.jetty</groupId>
-          <artifactId>jetty-servlet</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>junit</groupId>
-          <artifactId>junit</artifactId>
+          <groupId>org.tachyonproject</groupId>
+          <artifactId>tachyon-underfs-glusterfs</artifactId>
        </exclusion>
        <exclusion>
-          <groupId>org.powermock</groupId>
-          <artifactId>powermock-module-junit4</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.powermock</groupId>
-          <artifactId>powermock-api-mockito</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.apache.curator</groupId>
-          <artifactId>curator-test</artifactId>
+          <groupId>org.tachyonproject</groupId>
+          <artifactId>tachyon-underfs-s3</artifactId>
        </exclusion>
      </exclusions>
    </dependency>
@@ -342,6 +323,16 @@
      <artifactId>xml-apis</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-library</artifactId>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.mockito</groupId>
      <artifactId>mockito-core</artifactId>
@@ -358,18 +349,13 @@
      <scope>test</scope>
    </dependency>
    <dependency>
-      <groupId>org.hamcrest</groupId>
-      <artifactId>hamcrest-core</artifactId>
-      <scope>test</scope>
-    </dependency>
-    <dependency>
-      <groupId>org.hamcrest</groupId>
-      <artifactId>hamcrest-library</artifactId>
+      <groupId>com.novocode</groupId>
+      <artifactId>junit-interface</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
-      <groupId>com.novocode</groupId>
-      <artifactId>junit-interface</artifactId>
+      <groupId>org.apache.curator</groupId>
+      <artifactId>curator-test</artifactId>
      <scope>test</scope>
    </dependency>
diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java
index 646496f313507..fa9acf0a15b88 100644
--- a/core/src/main/java/org/apache/spark/JavaSparkListener.java
+++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java
@@ -17,23 +17,7 @@
package org.apache.spark;
-import org.apache.spark.scheduler.SparkListener;
-import org.apache.spark.scheduler.SparkListenerApplicationEnd;
-import org.apache.spark.scheduler.SparkListenerApplicationStart;
-import org.apache.spark.scheduler.SparkListenerBlockManagerAdded;
-import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved;
-import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate;
-import org.apache.spark.scheduler.SparkListenerExecutorAdded;
-import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate;
-import org.apache.spark.scheduler.SparkListenerExecutorRemoved;
-import org.apache.spark.scheduler.SparkListenerJobEnd;
-import org.apache.spark.scheduler.SparkListenerJobStart;
-import org.apache.spark.scheduler.SparkListenerStageCompleted;
-import org.apache.spark.scheduler.SparkListenerStageSubmitted;
-import org.apache.spark.scheduler.SparkListenerTaskEnd;
-import org.apache.spark.scheduler.SparkListenerTaskGettingResult;
-import org.apache.spark.scheduler.SparkListenerTaskStart;
-import org.apache.spark.scheduler.SparkListenerUnpersistRDD;
+import org.apache.spark.scheduler.*;
/**
* Java clients should extend this class instead of implementing
@@ -94,4 +78,8 @@ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { }
@Override
public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { }
+
+ @Override
+ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { }
+
}
diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
index fbc5666959055..1214d05ba6063 100644
--- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
+++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
@@ -112,4 +112,10 @@ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) {
public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) {
onEvent(executorRemoved);
}
+
+ @Override
+ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) {
+ onEvent(blockUpdated);
+ }
+
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
similarity index 91%
rename from core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
rename to core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
index 3f746b886bc9b..0399abc63c235 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
+++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.serializer;
import java.io.IOException;
import java.io.InputStream;
@@ -24,9 +24,7 @@
import scala.reflect.ClassTag;
-import org.apache.spark.serializer.DeserializationStream;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.PlatformDependent;
/**
@@ -35,7 +33,8 @@
* `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
* around this, we pass a dummy no-op serializer.
*/
-final class DummySerializerInstance extends SerializerInstance {
+@Private
+public final class DummySerializerInstance extends SerializerInstance {
public static final DummySerializerInstance INSTANCE = new DummySerializerInstance();
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index d3d6280284beb..0b8b604e18494 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
private final Serializer serializer;
/** Array of file writers, one for each partition */
- private BlockObjectWriter[] partitionWriters;
+ private DiskBlockObjectWriter[] partitionWriters;
public BypassMergeSortShuffleWriter(
SparkConf conf,
@@ -101,7 +101,7 @@ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
}
final SerializerInstance serInstance = serializer.newInstance();
final long openStartTime = System.nanoTime();
- partitionWriters = new BlockObjectWriter[numPartitions];
+ partitionWriters = new DiskBlockObjectWriter[numPartitions];
for (int i = 0; i < numPartitions; i++) {
final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
@@ -121,7 +121,7 @@ public void insertAll(Iterator> records) throws IOException {
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}
- for (BlockObjectWriter writer : partitionWriters) {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
writer.commitAndClose();
}
}
@@ -169,7 +169,7 @@ public void stop() throws IOException {
if (partitionWriters != null) {
try {
final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
- for (BlockObjectWriter writer : partitionWriters) {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
writer.revertPartialWritesAndClose();
if (!diskBlockManager.getFile(writer.blockId()).delete()) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 9e9ed94b7890c..1aa6ba4201261 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -30,6 +30,7 @@
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
@@ -58,14 +59,14 @@ final class UnsafeShuffleExternalSorter {
private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
- private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
- @VisibleForTesting
- static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
private final int initialSize;
private final int numPartitions;
+ private final int pageSizeBytes;
+ @VisibleForTesting
+ final int maxRecordSizeBytes;
private final TaskMemoryManager memoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
@@ -108,7 +109,10 @@ public UnsafeShuffleExternalSorter(
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
-
+ this.pageSizeBytes = (int) Math.min(
+ PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
+ conf.getSizeAsBytes("spark.buffer.pageSize", "64m"));
+ this.maxRecordSizeBytes = pageSizeBytes - 4;
this.writeMetrics = writeMetrics;
initializeForWriting();
}
@@ -156,7 +160,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
// after SPARK-5581 is fixed.
- BlockObjectWriter writer;
+ DiskBlockObjectWriter writer;
// Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
// be an API to directly transfer bytes from managed memory to the disk writer, we buffer
@@ -271,7 +275,11 @@ void spill() throws IOException {
}
private long getMemoryUsage() {
- return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+ long totalPageSize = 0;
+ for (MemoryBlock page : allocatedPages) {
+ totalPageSize += page.size();
+ }
+ return sorter.getMemoryUsage() + totalPageSize;
}
private long freeMemory() {
@@ -345,23 +353,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
// TODO: we should track metrics on the amount of space wasted when we roll over to a new page
// without using the free space at the end of the current page. We should also do this for
// BytesToBytesMap.
- if (requiredSpace > PAGE_SIZE) {
+ if (requiredSpace > pageSizeBytes) {
throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
- PAGE_SIZE + ")");
+ pageSizeBytes + ")");
} else {
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
- if (memoryAcquired < PAGE_SIZE) {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquired < pageSizeBytes) {
shuffleMemoryManager.release(memoryAcquired);
spill();
- final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
- if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquiredAfterSpilling != pageSizeBytes) {
shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
- throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+ throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
}
}
- currentPage = memoryManager.allocatePage(PAGE_SIZE);
+ currentPage = memoryManager.allocatePage(pageSizeBytes);
currentPagePosition = currentPage.getBaseOffset();
- freeSpaceInCurrentPage = PAGE_SIZE;
+ freeSpaceInCurrentPage = pageSizeBytes;
allocatedPages.add(currentPage);
}
}
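
As a usage sketch (not part of the patch), the page size introduced above is driven by the spark.buffer.pageSize setting. The snippet below only shows how the size string is parsed with getSizeAsBytes; the cap against PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES is applied inside the sorter as shown in the hunk above. The class name is hypothetical:

    import org.apache.spark.SparkConf;

    public class PageSizeConfigSketch {
      public static void main(String[] args) {
        // getSizeAsBytes accepts suffixed values such as "32m".
        SparkConf conf = new SparkConf().set("spark.buffer.pageSize", "32m");
        long requested = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
        System.out.println(requested); // 33554432
      }
    }
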
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 764578b181422..d47d6fc9c2ac4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -129,6 +129,11 @@ public UnsafeShuffleWriter(
open();
}
+ @VisibleForTesting
+ public int maxRecordSizeBytes() {
+ return sorter.maxRecordSizeBytes;
+ }
+
/**
* This convenience method should only be called in test code.
*/
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java
new file mode 100644
index 0000000000000..45b78829e4cf7
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific
+ * comparisons, such as lexicographic comparison for strings.
+ */
+@Private
+public abstract class PrefixComparator {
+ public abstract int compare(long prefix1, long prefix2);
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
new file mode 100644
index 0000000000000..4d7e5b3dfba6e
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import com.google.common.primitives.UnsignedLongs;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.Utils;
+
+@Private
+public class PrefixComparators {
+ private PrefixComparators() {}
+
+ public static final StringPrefixComparator STRING = new StringPrefixComparator();
+ public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
+ public static final LongPrefixComparator LONG = new LongPrefixComparator();
+ public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
+ public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
+ public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc();
+
+ public static final class StringPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ return UnsignedLongs.compare(aPrefix, bPrefix);
+ }
+
+ public static long computePrefix(UTF8String value) {
+ return value == null ? 0L : value.getPrefix();
+ }
+ }
+
+ public static final class StringPrefixComparatorDesc extends PrefixComparator {
+ @Override
+ public int compare(long bPrefix, long aPrefix) {
+ return UnsignedLongs.compare(aPrefix, bPrefix);
+ }
+ }
+
+ public static final class LongPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long a, long b) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+ }
+
+ public static final class LongPrefixComparatorDesc extends PrefixComparator {
+ @Override
+ public int compare(long b, long a) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+ }
+
+ public static final class DoublePrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ double a = Double.longBitsToDouble(aPrefix);
+ double b = Double.longBitsToDouble(bPrefix);
+ return Utils.nanSafeCompareDoubles(a, b);
+ }
+
+ public static long computePrefix(double value) {
+ return Double.doubleToLongBits(value);
+ }
+ }
+
+ public static final class DoublePrefixComparatorDesc extends PrefixComparator {
+ @Override
+ public int compare(long bPrefix, long aPrefix) {
+ double a = Double.longBitsToDouble(aPrefix);
+ double b = Double.longBitsToDouble(bPrefix);
+ return Utils.nanSafeCompareDoubles(a, b);
+ }
+
+ public static long computePrefix(double value) {
+ return Double.doubleToLongBits(value);
+ }
+ }
+}
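
A small usage sketch (assuming the classes added above are on the classpath; the class name below is hypothetical) showing that longs serve as their own prefix, while doubles are packed into 8 bytes and decoded again when compared:

    import org.apache.spark.util.collection.unsafe.sort.PrefixComparators;

    public class PrefixComparatorUsageSketch {
      public static void main(String[] args) {
        // Longs act as their own prefix, so the comparator orders them directly.
        System.out.println(PrefixComparators.LONG.compare(3L, 7L));      // negative: 3 before 7
        System.out.println(PrefixComparators.LONG_DESC.compare(3L, 7L)); // positive: 7 before 3

        // Doubles are packed with doubleToLongBits and decoded again inside the comparator.
        long a = PrefixComparators.DoublePrefixComparator.computePrefix(-1.5);
        long b = PrefixComparators.DoublePrefixComparator.computePrefix(2.0);
        System.out.println(PrefixComparators.DOUBLE.compare(a, b));      // negative: -1.5 before 2.0
      }
    }
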
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
new file mode 100644
index 0000000000000..09e4258792204
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+/**
+ * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte
+ * prefix, this may simply return 0.
+ */
+public abstract class RecordComparator {
+
+ /**
+ * Compare two records for order.
+ *
+ * @return a negative integer, zero, or a positive integer as the first record is less than,
+ * equal to, or greater than the second.
+ */
+ public abstract int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset);
+}
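
For illustration, a hypothetical RecordComparator for records whose payload is a single 4-byte int read directly from the record's base object and offset, mirroring how the sorter code in this patch reads record lengths via PlatformDependent:

    import org.apache.spark.unsafe.PlatformDependent;
    import org.apache.spark.util.collection.unsafe.sort.RecordComparator;

    final class IntRecordComparator extends RecordComparator {
      @Override
      public int compare(
          Object leftBaseObject, long leftBaseOffset,
          Object rightBaseObject, long rightBaseOffset) {
        // Each record is assumed to be a single 4-byte int at its base offset.
        final int left = PlatformDependent.UNSAFE.getInt(leftBaseObject, leftBaseOffset);
        final int right = PlatformDependent.UNSAFE.getInt(rightBaseObject, rightBaseOffset);
        return Integer.compare(left, right);
      }
    }
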
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
new file mode 100644
index 0000000000000..0c4ebde407cfc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+final class RecordPointerAndKeyPrefix {
+ /**
+ * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
+ * description of how these addresses are encoded.
+ */
+ public long recordPointer;
+
+ /**
+ * A key prefix, for use in comparisons.
+ */
+ public long keyPrefix;
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
new file mode 100644
index 0000000000000..866e0b4151577
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+import java.util.LinkedList;
+
+import scala.runtime.AbstractFunction0;
+import scala.runtime.BoxedUnit;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * External sorter based on {@link UnsafeInMemorySorter}.
+ */
+public final class UnsafeExternalSorter {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
+
+ private final long pageSizeBytes;
+ private final PrefixComparator prefixComparator;
+ private final RecordComparator recordComparator;
+ private final int initialSize;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final BlockManager blockManager;
+ private final TaskContext taskContext;
+ private ShuffleWriteMetrics writeMetrics;
+
+ /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+ private final int fileBufferSizeBytes;
+
+ /**
+ * Memory pages that hold the records being sorted. The pages in this list are freed when
+ * spilling, although in principle we could recycle these pages across spills (on the other hand,
+ * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+ * itself).
+ */
+ private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+ // These variables are reset after spilling:
+ private UnsafeInMemorySorter sorter;
+ private MemoryBlock currentPage = null;
+ private long currentPagePosition = -1;
+ private long freeSpaceInCurrentPage = 0;
+
+ private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+
+ public UnsafeExternalSorter(
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ int initialSize,
+ SparkConf conf) throws IOException {
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.blockManager = blockManager;
+ this.taskContext = taskContext;
+ this.recordComparator = recordComparator;
+ this.prefixComparator = prefixComparator;
+ this.initialSize = initialSize;
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
+ this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
+ initializeForWriting();
+
+ // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
+ // the end of the task. This is necessary to avoid memory leaks when the downstream operator
+ // does not fully consume the sorter's output (e.g. sort followed by limit).
+ taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() {
+ @Override
+ public BoxedUnit apply() {
+ freeMemory();
+ return null;
+ }
+ });
+ }
+
+ // TODO: metrics tracking + integration with shuffle write metrics
+ // need to connect the write metrics to task metrics so we count the spill IO somewhere.
+
+ /**
+ * Allocates new sort data structures. Called when creating the sorter and after each spill.
+ */
+ private void initializeForWriting() throws IOException {
+ this.writeMetrics = new ShuffleWriteMetrics();
+ // TODO: move this sizing calculation logic into a static method of sorter:
+ final long memoryRequested = initialSize * 8L * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryAcquired != memoryRequested) {
+ shuffleMemoryManager.release(memoryAcquired);
+ throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ }
+
+ this.sorter =
+ new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize);
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
+ @VisibleForTesting
+ public void spill() throws IOException {
+ logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spillWriters.size(),
+ spillWriters.size() > 1 ? " times" : " time");
+
+ final UnsafeSorterSpillWriter spillWriter =
+ new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+ sorter.numRecords());
+ spillWriters.add(spillWriter);
+ final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final Object baseObject = sortedRecords.getBaseObject();
+ final long baseOffset = sortedRecords.getBaseOffset();
+ final int recordLength = sortedRecords.getRecordLength();
+ spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+ }
+ spillWriter.close();
+ final long sorterMemoryUsage = sorter.getMemoryUsage();
+ sorter = null;
+ shuffleMemoryManager.release(sorterMemoryUsage);
+ final long spillSize = freeMemory();
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+ initializeForWriting();
+ }
+
+ private long getMemoryUsage() {
+ long totalPageSize = 0;
+ for (MemoryBlock page : allocatedPages) {
+ totalPageSize += page.size();
+ }
+ return sorter.getMemoryUsage() + totalPageSize;
+ }
+
+ @VisibleForTesting
+ public int getNumberOfAllocatedPages() {
+ return allocatedPages.size();
+ }
+
+ public long freeMemory() {
+ long memoryFreed = 0;
+ for (MemoryBlock block : allocatedPages) {
+ memoryManager.freePage(block);
+ shuffleMemoryManager.release(block.size());
+ memoryFreed += block.size();
+ }
+ allocatedPages.clear();
+ currentPage = null;
+ currentPagePosition = -1;
+ freeSpaceInCurrentPage = 0;
+ return memoryFreed;
+ }
+
+ /**
+ * Checks whether there is enough space to insert a new record into the sorter.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ *
+ * @return true if the record can be inserted without requiring more allocations, false otherwise.
+ */
+ private boolean haveSpaceForRecord(int requiredSpace) {
+ assert (requiredSpace > 0);
+ return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+ }
+
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ */
+ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+ // TODO: merge these steps to first calculate total memory requirements for this insert,
+ // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
+ // data page.
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ logger.debug("Attempting to expand sort pointer array");
+ final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+ if (memoryAcquired < memoryToGrowPointerArray) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ } else {
+ sorter.expandPointerArray();
+ shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+ }
+ }
+
+ if (requiredSpace > freeSpaceInCurrentPage) {
+ logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
+ freeSpaceInCurrentPage);
+ // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+ // without using the free space at the end of the current page. We should also do this for
+ // BytesToBytesMap.
+ if (requiredSpace > pageSizeBytes) {
+ throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+ pageSizeBytes + ")");
+ } else {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquired < pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquiredAfterSpilling != pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
+ }
+ }
+ currentPage = memoryManager.allocatePage(pageSizeBytes);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = pageSizeBytes;
+ allocatedPages.add(currentPage);
+ }
+ }
+ }
+
+ /**
+ * Write a record to the sorter.
+ */
+ public void insertRecord(
+ Object recordBaseObject,
+ long recordBaseOffset,
+ int lengthInBytes,
+ long prefix) throws IOException {
+ // Need 4 bytes to store the record length.
+ final int totalSpaceRequired = lengthInBytes + 4;
+ if (!haveSpaceForRecord(totalSpaceRequired)) {
+ allocateSpaceForRecord(totalSpaceRequired);
+ }
+
+ final long recordAddress =
+ memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+ final Object dataPageBaseObject = currentPage.getBaseObject();
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
+ currentPagePosition += 4;
+ PlatformDependent.copyMemory(
+ recordBaseObject,
+ recordBaseOffset,
+ dataPageBaseObject,
+ currentPagePosition,
+ lengthInBytes);
+ currentPagePosition += lengthInBytes;
+ freeSpaceInCurrentPage -= totalSpaceRequired;
+ sorter.insertRecord(recordAddress, prefix);
+ }
+
+ public UnsafeSorterIterator getSortedIterator() throws IOException {
+ final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
+ int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
+ if (spillWriters.isEmpty()) {
+ return inMemoryIterator;
+ } else {
+ final UnsafeSorterSpillMerger spillMerger =
+ new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
+ for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+ spillMerger.addSpill(spillWriter.getReader(blockManager));
+ }
+ spillWriters.clear();
+ if (inMemoryIterator.hasNext()) {
+ spillMerger.addSpill(inMemoryIterator);
+ }
+ return spillMerger.getSortedIterator();
+ }
+ }
+}
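
The allocateSpaceForRecord() logic above follows a try-acquire / spill / retry-once pattern. A self-contained sketch of that pattern, with hypothetical MemoryBroker and Spillable stand-ins for ShuffleMemoryManager and the sorter itself:

    import java.io.IOException;

    final class AcquireOrSpillSketch {
      interface MemoryBroker {        // stand-in for ShuffleMemoryManager
        long tryToAcquire(long bytes);
        void release(long bytes);
      }

      interface Spillable {           // stand-in for the sorter's spill()
        void spill() throws IOException;
      }

      // Acquire a full page or fail: try, give back a partial grant, spill, then retry once.
      static void acquirePage(MemoryBroker broker, Spillable sorter, long pageSizeBytes)
          throws IOException {
        long acquired = broker.tryToAcquire(pageSizeBytes);
        if (acquired < pageSizeBytes) {
          broker.release(acquired);
          sorter.spill();
          long retry = broker.tryToAcquire(pageSizeBytes);
          if (retry != pageSizeBytes) {
            broker.release(retry);
            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
          }
        }
      }
    }
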
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
new file mode 100644
index 0000000000000..fc34ad9cff369
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.util.Comparator;
+
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+/**
+ * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
+ * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm
+ * compares records, it will first compare the stored key prefixes; if the prefixes are not equal,
+ * then we do not need to traverse the record pointers to compare the actual records. Avoiding these
+ * random memory accesses improves cache hit rates.
+ */
+public final class UnsafeInMemorySorter {
+
+ private static final class SortComparator implements Comparator<RecordPointerAndKeyPrefix> {
+
+ private final RecordComparator recordComparator;
+ private final PrefixComparator prefixComparator;
+ private final TaskMemoryManager memoryManager;
+
+ SortComparator(
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ TaskMemoryManager memoryManager) {
+ this.recordComparator = recordComparator;
+ this.prefixComparator = prefixComparator;
+ this.memoryManager = memoryManager;
+ }
+
+ @Override
+ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
+ final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix);
+ if (prefixComparisonResult == 0) {
+ final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
+ final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length
+ final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
+ final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length
+ return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ }
+
+ private final TaskMemoryManager memoryManager;
+ private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
+ private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
+
+ /**
+ * Within this buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+ private long[] pointerArray;
+
+ /**
+ * The position in the sort buffer where new records can be inserted.
+ */
+ private int pointerArrayInsertPosition = 0;
+
+ public UnsafeInMemorySorter(
+ final TaskMemoryManager memoryManager,
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ int initialSize) {
+ assert (initialSize > 0);
+ this.pointerArray = new long[initialSize * 2];
+ this.memoryManager = memoryManager;
+ this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
+ this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ }
+
+ /**
+ * @return the number of records that have been inserted into this sorter.
+ */
+ public int numRecords() {
+ return pointerArrayInsertPosition / 2;
+ }
+
+ public long getMemoryUsage() {
+ return pointerArray.length * 8L;
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
+ return pointerArrayInsertPosition + 2 < pointerArray.length;
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = pointerArray;
+ // Guard against overflow:
+ final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+ pointerArray = new long[newLength];
+ System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ }
+
+ /**
+ * Inserts a record to be sorted. Assumes that the record pointer points to a record length
+ * stored as a 4-byte integer, followed by the record's bytes.
+ *
+ * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
+ * @param keyPrefix a user-defined key prefix
+ */
+ public void insertRecord(long recordPointer, long keyPrefix) {
+ if (!hasSpaceForAnotherRecord()) {
+ expandPointerArray();
+ }
+ pointerArray[pointerArrayInsertPosition] = recordPointer;
+ pointerArrayInsertPosition++;
+ pointerArray[pointerArrayInsertPosition] = keyPrefix;
+ pointerArrayInsertPosition++;
+ }
+
+ private static final class SortedIterator extends UnsafeSorterIterator {
+
+ private final TaskMemoryManager memoryManager;
+ private final int sortBufferInsertPosition;
+ private final long[] sortBuffer;
+ private int position = 0;
+ private Object baseObject;
+ private long baseOffset;
+ private long keyPrefix;
+ private int recordLength;
+
+ SortedIterator(
+ TaskMemoryManager memoryManager,
+ int sortBufferInsertPosition,
+ long[] sortBuffer) {
+ this.memoryManager = memoryManager;
+ this.sortBufferInsertPosition = sortBufferInsertPosition;
+ this.sortBuffer = sortBuffer;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return position < sortBufferInsertPosition;
+ }
+
+ @Override
+ public void loadNext() {
+ // This pointer points to a 4-byte record length, followed by the record's bytes
+ final long recordPointer = sortBuffer[position];
+ baseObject = memoryManager.getPage(recordPointer);
+ baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
+ recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4);
+ keyPrefix = sortBuffer[position + 1];
+ position += 2;
+ }
+
+ @Override
+ public Object getBaseObject() { return baseObject; }
+
+ @Override
+ public long getBaseOffset() { return baseOffset; }
+
+ @Override
+ public int getRecordLength() { return recordLength; }
+
+ @Override
+ public long getKeyPrefix() { return keyPrefix; }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order. For efficiency, all calls to
+ * {@code next()} will return the same mutable object.
+ */
+ public UnsafeSorterIterator getSortedIterator() {
+ sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
+ return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
+ }
+}
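
A self-contained sketch (plain Java, not Spark code) of the prefix-first comparison idea described in the class comment above: cached 8-byte prefixes are compared first, and the full records are touched only when the prefixes tie. All names below are illustrative:

    import java.util.Arrays;
    import java.util.Comparator;

    final class PrefixSortSketch {
      public static void main(String[] args) {
        final String[] records = {"banana", "apple", "apricot"};
        final long[] prefixes = new long[records.length];
        for (int i = 0; i < records.length; i++) {
          prefixes[i] = prefixOf(records[i]);
        }
        Integer[] order = {0, 1, 2};
        Arrays.sort(order, new Comparator<Integer>() {
          @Override
          public int compare(Integer a, Integer b) {
            // Compare the cached prefixes first; only dereference the records on a tie.
            int c = Long.compare(prefixes[a], prefixes[b]); // signed compare is fine for ASCII
            return (c != 0) ? c : records[a].compareTo(records[b]);
          }
        });
        System.out.println(Arrays.toString(order)); // [1, 2, 0] -> apple, apricot, banana
      }

      // Pack up to 8 leading bytes of the string into a big-endian long key prefix.
      private static long prefixOf(String s) {
        final byte[] bytes = s.getBytes(java.nio.charset.StandardCharsets.UTF_8);
        long prefix = 0L;
        for (int i = 0; i < 8; i++) {
          prefix = (prefix << 8) | (i < bytes.length ? (bytes[i] & 0xffL) : 0L);
        }
        return prefix;
      }
    }
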
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
new file mode 100644
index 0000000000000..d09c728a7a638
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+/**
+ * Supports sorting an array of (record pointer, key prefix) pairs.
+ * Used in {@link UnsafeInMemorySorter}.
+ *
+ * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
+
+ public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
+
+ private UnsafeSortDataFormat() { }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+ // Since we re-use keys, this method shouldn't be called.
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix newKey() {
+ return new RecordPointerAndKeyPrefix();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
+ reuse.recordPointer = data[pos * 2];
+ reuse.keyPrefix = data[pos * 2 + 1];
+ return reuse;
+ }
+
+ @Override
+ public void swap(long[] data, int pos0, int pos1) {
+ long tempPointer = data[pos0 * 2];
+ long tempKeyPrefix = data[pos0 * 2 + 1];
+ data[pos0 * 2] = data[pos1 * 2];
+ data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
+ data[pos1 * 2] = tempPointer;
+ data[pos1 * 2 + 1] = tempKeyPrefix;
+ }
+
+ @Override
+ public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+ dst[dstPos * 2] = src[srcPos * 2];
+ dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+ }
+
+ @Override
+ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+ System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+ }
+
+ @Override
+ public long[] allocate(int length) {
+ assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
+ return new long[length * 2];
+ }
+
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
new file mode 100644
index 0000000000000..16ac2e8d821ba
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+
+public abstract class UnsafeSorterIterator {
+
+ public abstract boolean hasNext();
+
+ public abstract void loadNext() throws IOException;
+
+ public abstract Object getBaseObject();
+
+ public abstract long getBaseOffset();
+
+ public abstract int getRecordLength();
+
+ public abstract long getKeyPrefix();
+}
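
For illustration, a hypothetical UnsafeSorterIterator over records held in plain on-heap byte arrays, showing the loadNext()/getBaseObject()/getBaseOffset() cursor contract used throughout this patch:

    import org.apache.spark.unsafe.PlatformDependent;
    import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;

    final class OnHeapSorterIterator extends UnsafeSorterIterator {
      private final byte[][] records;
      private final long[] prefixes;
      private int position = -1;

      OnHeapSorterIterator(byte[][] records, long[] prefixes) {
        this.records = records;
        this.prefixes = prefixes;
      }

      @Override public boolean hasNext() { return position + 1 < records.length; }

      // Advances the cursor; callers then read the current record through the getters below.
      @Override public void loadNext() { position++; }

      @Override public Object getBaseObject() { return records[position]; }

      // For an on-heap byte[], the record data starts at the array's base offset.
      @Override public long getBaseOffset() { return PlatformDependent.BYTE_ARRAY_OFFSET; }

      @Override public int getRecordLength() { return records[position].length; }

      @Override public long getKeyPrefix() { return prefixes[position]; }
    }
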
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
new file mode 100644
index 0000000000000..8272c2a5be0d1
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.PriorityQueue;
+
+final class UnsafeSorterSpillMerger {
+
+ private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
+
+ public UnsafeSorterSpillMerger(
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ final int numSpills) {
+ final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() {
+
+ @Override
+ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
+ final int prefixComparisonResult =
+ prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
+ if (prefixComparisonResult == 0) {
+ return recordComparator.compare(
+ left.getBaseObject(), left.getBaseOffset(),
+ right.getBaseObject(), right.getBaseOffset());
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ };
+ priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
+ }
+
+ public void addSpill(UnsafeSorterIterator spillReader) throws IOException {
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ }
+ priorityQueue.add(spillReader);
+ }
+
+ public UnsafeSorterIterator getSortedIterator() throws IOException {
+ return new UnsafeSorterIterator() {
+
+ private UnsafeSorterIterator spillReader;
+
+ @Override
+ public boolean hasNext() {
+ return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ if (spillReader != null) {
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ priorityQueue.add(spillReader);
+ }
+ }
+ spillReader = priorityQueue.remove();
+ }
+
+ @Override
+ public Object getBaseObject() { return spillReader.getBaseObject(); }
+
+ @Override
+ public long getBaseOffset() { return spillReader.getBaseOffset(); }
+
+ @Override
+ public int getRecordLength() { return spillReader.getRecordLength(); }
+
+ @Override
+ public long getKeyPrefix() { return spillReader.getKeyPrefix(); }
+ };
+ }
+}
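
A self-contained sketch (plain Java, illustrative names) of the same k-way merge idea: keep one loaded head element per source in a priority queue and repeatedly pop the smallest, refilling from the source it came from:

    import java.util.Arrays;
    import java.util.Comparator;
    import java.util.Iterator;
    import java.util.List;
    import java.util.PriorityQueue;

    final class KWayMergeSketch {
      public static void main(String[] args) {
        List<Iterator<Integer>> sources = Arrays.asList(
            Arrays.asList(1, 4, 9).iterator(),
            Arrays.asList(2, 3, 10).iterator(),
            Arrays.asList(5, 6, 7).iterator());

        // Each queue entry is {currentValue, sourceIndex}, ordered by currentValue.
        PriorityQueue<int[]> heap = new PriorityQueue<int[]>(sources.size(), new Comparator<int[]>() {
          @Override
          public int compare(int[] a, int[] b) { return Integer.compare(a[0], b[0]); }
        });
        for (int i = 0; i < sources.size(); i++) {
          if (sources.get(i).hasNext()) {
            heap.add(new int[] {sources.get(i).next(), i});
          }
        }
        while (!heap.isEmpty()) {
          int[] head = heap.remove();
          System.out.print(head[0] + " "); // 1 2 3 4 5 6 7 9 10
          Iterator<Integer> source = sources.get(head[1]);
          if (source.hasNext()) {
            heap.add(new int[] {source.next(), head[1]});
          }
        }
      }
    }
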
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
new file mode 100644
index 0000000000000..29e9e0f30f934
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.*;
+
+import com.google.common.io.ByteStreams;
+
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
+ * of the file format).
+ */
+final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+
+ private InputStream in;
+ private DataInputStream din;
+
+ // Variables that change with every record read:
+ private int recordLength;
+ private long keyPrefix;
+ private int numRecordsRemaining;
+
+ private byte[] arr = new byte[1024 * 1024];
+ private Object baseObject = arr;
+ private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+
+ public UnsafeSorterSpillReader(
+ BlockManager blockManager,
+ File file,
+ BlockId blockId) throws IOException {
+ assert (file.length() > 0);
+ final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
+ this.in = blockManager.wrapForCompression(blockId, bs);
+ this.din = new DataInputStream(this.in);
+ numRecordsRemaining = din.readInt();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (numRecordsRemaining > 0);
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ recordLength = din.readInt();
+ keyPrefix = din.readLong();
+ if (recordLength > arr.length) {
+ arr = new byte[recordLength];
+ baseObject = arr;
+ }
+ ByteStreams.readFully(in, arr, 0, recordLength);
+ numRecordsRemaining--;
+ if (numRecordsRemaining == 0) {
+ in.close();
+ in = null;
+ din = null;
+ }
+ }
+
+ @Override
+ public Object getBaseObject() {
+ return baseObject;
+ }
+
+ @Override
+ public long getBaseOffset() {
+ return baseOffset;
+ }
+
+ @Override
+ public int getRecordLength() {
+ return recordLength;
+ }
+
+ @Override
+ public long getKeyPrefix() {
+ return keyPrefix;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
new file mode 100644
index 0000000000000..71eed29563d4a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Tuple2;
+
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.DiskBlockObjectWriter;
+import org.apache.spark.storage.TempLocalBlockId;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Spills a list of sorted records to disk. Spill files have the following format:
+ *
+ * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
+ */
+final class UnsafeSorterSpillWriter {
+
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array.
+ private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ private final File file;
+ private final BlockId blockId;
+ private final int numRecordsToWrite;
+ private DiskBlockObjectWriter writer;
+ private int numRecordsSpilled = 0;
+
+ public UnsafeSorterSpillWriter(
+ BlockManager blockManager,
+ int fileBufferSize,
+ ShuffleWriteMetrics writeMetrics,
+ int numRecordsToWrite) throws IOException {
+ final Tuple2<TempLocalBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempLocalBlock();
+ this.file = spilledFileInfo._2();
+ this.blockId = spilledFileInfo._1();
+ this.numRecordsToWrite = numRecordsToWrite;
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ writer = blockManager.getDiskWriter(
+ blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics);
+ // Write the number of records
+ writeIntToBuffer(numRecordsToWrite, 0);
+ writer.write(writeBuffer, 0, 4);
+ }
+
+ // Based on DataOutputStream.writeLong.
+ private void writeLongToBuffer(long v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 56);
+ writeBuffer[offset + 1] = (byte)(v >>> 48);
+ writeBuffer[offset + 2] = (byte)(v >>> 40);
+ writeBuffer[offset + 3] = (byte)(v >>> 32);
+ writeBuffer[offset + 4] = (byte)(v >>> 24);
+ writeBuffer[offset + 5] = (byte)(v >>> 16);
+ writeBuffer[offset + 6] = (byte)(v >>> 8);
+ writeBuffer[offset + 7] = (byte)(v >>> 0);
+ }
+
+ // Based on DataOutputStream.writeInt.
+ private void writeIntToBuffer(int v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 24);
+ writeBuffer[offset + 1] = (byte)(v >>> 16);
+ writeBuffer[offset + 2] = (byte)(v >>> 8);
+ writeBuffer[offset + 3] = (byte)(v >>> 0);
+ }
+
+ /**
+ * Write a record to a spill file.
+ *
+ * @param baseObject the base object / memory page containing the record
+ * @param baseOffset the base offset which points directly to the record data.
+ * @param recordLength the length of the record.
+ * @param keyPrefix a sort key prefix
+ */
+ public void write(
+ Object baseObject,
+ long baseOffset,
+ int recordLength,
+ long keyPrefix) throws IOException {
+ if (numRecordsSpilled == numRecordsToWrite) {
+ throw new IllegalStateException(
+ "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite);
+ } else {
+ numRecordsSpilled++;
+ }
+ writeIntToBuffer(recordLength, 0);
+ writeLongToBuffer(keyPrefix, 4);
+ int dataRemaining = recordLength;
+ int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len
+ long recordReadPosition = baseOffset;
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining);
+ PlatformDependent.copyMemory(
+ baseObject,
+ recordReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
+ toTransfer);
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE;
+ }
+ if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) {
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer));
+ }
+ writer.recordWritten();
+ }
+
+ public void close() throws IOException {
+ writer.commitAndClose();
+ writer = null;
+ writeBuffer = null;
+ }
+
+ public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
+ return new UnsafeSorterSpillReader(blockManager, file, blockId);
+ }
+}
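
To illustrate the spill-file layout described in the class comment above ([# of records][len][prefix][data]...), a hedged sketch that walks it with a plain DataInputStream; real spill files are additionally wrapped by blockManager.wrapForCompression, which this sketch ignores, and the class name is hypothetical:

    import java.io.BufferedInputStream;
    import java.io.DataInputStream;
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.IOException;

    final class SpillLayoutReaderSketch {
      static void dump(File file) throws IOException {
        DataInputStream in =
            new DataInputStream(new BufferedInputStream(new FileInputStream(file)));
        try {
          final int numRecords = in.readInt();          // [# of records (int)]
          for (int i = 0; i < numRecords; i++) {
            final int length = in.readInt();            // [len (int)]
            final long prefix = in.readLong();          // [prefix (long)]
            final byte[] data = new byte[length];       // [data (bytes)]
            in.readFully(data);
            System.out.println("record " + i + ": prefix=" + prefix + ", " + length + " bytes");
          }
        } finally {
          in.close();
        }
      }
    }
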
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
index b146f8a784127..689afea64f8db 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
@@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN
log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
+
+# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
+log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
+log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index 3a2a88219818f..27006e45e932b 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN
log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
+
+# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
+log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
+log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
index 0b450dc76bc38..3c8ddddf07b1e 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -19,6 +19,9 @@
* to be registered after the page loads. */
$(function() {
$("span.expand-additional-metrics").click(function(){
+ var status = window.localStorage.getItem("expand-additional-metrics") == "true";
+ status = !status;
+
// Expand the list of additional metrics.
var additionalMetricsDiv = $(this).parent().find('.additional-metrics');
$(additionalMetricsDiv).toggleClass('collapsed');
@@ -26,17 +29,31 @@ $(function() {
// Switch the class of the arrow from open to closed.
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open');
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed');
+
+ window.localStorage.setItem("expand-additional-metrics", "" + status);
});
+ if (window.localStorage.getItem("expand-additional-metrics") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-additional-metrics", "false");
+ $("span.expand-additional-metrics").trigger("click");
+ }
+
stripeSummaryTable();
$('input[type="checkbox"]').click(function() {
- var column = "table ." + $(this).attr("name");
+ var name = $(this).attr("name")
+ var column = "table ." + name;
+ var status = window.localStorage.getItem(name) == "true";
+ status = !status;
$(column).toggle();
stripeSummaryTable();
+ window.localStorage.setItem(name, "" + status);
});
$("#select-all-metrics").click(function() {
+ var status = window.localStorage.getItem("select-all-metrics") == "true";
+ status = !status;
if (this.checked) {
// Toggle all un-checked options.
$('input[type="checkbox"]:not(:checked)').trigger('click');
@@ -44,6 +61,21 @@ $(function() {
// Toggle all checked options.
$('input[type="checkbox"]:checked').trigger('click');
}
+ window.localStorage.setItem("select-all-metrics", "" + status);
+ });
+
+ if (window.localStorage.getItem("select-all-metrics") == "true") {
+ $("#select-all-metrics").attr('checked', status);
+ }
+
+ $("span.additional-metric-title").parent().find('input[type="checkbox"]').each(function() {
+ var name = $(this).attr("name")
+ // If name is undefined, then skip it because it's the "select-all-metrics" checkbox
+ if (name && window.localStorage.getItem(name) == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem(name, "false");
+ $(this).trigger("click")
+ }
});
// Trigger a click on the checkbox if a user clicks the label next to it.
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
index 9fa53baaf4212..4a893bc0189aa 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
@@ -72,6 +72,14 @@ var StagePageVizConstants = {
rankSep: 40
};
+/*
+ * Return "expand-dag-viz-arrow-job" if forJob is true.
+ * Otherwise, return "expand-dag-viz-arrow-stage".
+ */
+function expandDagVizArrowKey(forJob) {
+ return forJob ? "expand-dag-viz-arrow-job" : "expand-dag-viz-arrow-stage";
+}
+
/*
* Show or hide the RDD DAG visualization.
*
@@ -79,6 +87,9 @@ var StagePageVizConstants = {
* This is the narrow interface called from the Scala UI code.
*/
function toggleDagViz(forJob) {
+ var status = window.localStorage.getItem(expandDagVizArrowKey(forJob)) == "true";
+ status = !status;
+
var arrowSelector = ".expand-dag-viz-arrow";
$(arrowSelector).toggleClass('arrow-closed');
$(arrowSelector).toggleClass('arrow-open');
@@ -93,8 +104,24 @@ function toggleDagViz(forJob) {
// Save the graph for later so we don't have to render it again
graphContainer().style("display", "none");
}
+
+ window.localStorage.setItem(expandDagVizArrowKey(forJob), "" + status);
}
+$(function (){
+ if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem(expandDagVizArrowKey(false), "false");
+ toggleDagViz(false);
+ }
+
+ if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem(expandDagVizArrowKey(true), "false");
+ toggleDagViz(true);
+ }
+});
+
/*
* Render the RDD DAG visualization.
*
diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
index ca74ef9d7e94e..f4453c71df1ea 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
@@ -66,14 +66,27 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) {
setupJobEventAction();
$("span.expand-application-timeline").click(function() {
+ var status = window.localStorage.getItem("expand-application-timeline") == "true";
+ status = !status;
+
$("#application-timeline").toggleClass('collapsed');
// Switch the class of the arrow from open to closed.
$(this).find('.expand-application-timeline-arrow').toggleClass('arrow-open');
$(this).find('.expand-application-timeline-arrow').toggleClass('arrow-closed');
+
+ window.localStorage.setItem("expand-application-timeline", "" + status);
});
}
+$(function (){
+ if (window.localStorage.getItem("expand-application-timeline") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-application-timeline", "false");
+ $("span.expand-application-timeline").trigger('click');
+ }
+});
+
function drawJobTimeline(groupArray, eventObjArray, startTime) {
var groups = new vis.DataSet(groupArray);
var items = new vis.DataSet(eventObjArray);
@@ -125,14 +138,27 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) {
setupStageEventAction();
$("span.expand-job-timeline").click(function() {
+ var status = window.localStorage.getItem("expand-job-timeline") == "true";
+ status = !status;
+
$("#job-timeline").toggleClass('collapsed');
// Switch the class of the arrow from open to closed.
$(this).find('.expand-job-timeline-arrow').toggleClass('arrow-open');
$(this).find('.expand-job-timeline-arrow').toggleClass('arrow-closed');
+
+ window.localStorage.setItem("expand-job-timeline", "" + status);
});
}
+$(function (){
+ if (window.localStorage.getItem("expand-job-timeline") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-job-timeline", "false");
+ $("span.expand-job-timeline").trigger('click');
+ }
+});
+
function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) {
var groups = new vis.DataSet(groupArray);
var items = new vis.DataSet(eventObjArray);
@@ -176,14 +202,27 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, ma
setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline);
$("span.expand-task-assignment-timeline").click(function() {
+ var status = window.localStorage.getItem("expand-task-assignment-timeline") == "true";
+ status = !status;
+
$("#task-assignment-timeline").toggleClass("collapsed");
// Switch the class of the arrow from open to closed.
$(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open");
$(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed");
+
+ window.localStorage.setItem("expand-task-assignment-timeline", "" + status);
});
}
+$(function (){
+ if (window.localStorage.getItem("expand-task-assignment-timeline") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-task-assignment-timeline", "false");
+ $("span.expand-task-assignment-timeline").trigger('click');
+ }
+});
+
function setupExecutorEventAction() {
$(".item.box.executor").each(function () {
$(this).hover(
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 5a8d17bd99933..eb75f26718e19 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -20,7 +20,8 @@ package org.apache.spark
import java.io.{ObjectInputStream, Serializable}
import scala.collection.generic.Growable
-import scala.collection.mutable.Map
+import scala.collection.Map
+import scala.collection.mutable
import scala.ref.WeakReference
import scala.reflect.ClassTag
@@ -39,25 +40,44 @@ import org.apache.spark.util.Utils
* @param initialValue initial value of accumulator
* @param param helper object defining how to add elements of type `R` and `T`
* @param name human-readable name for use in Spark's web UI
+ * @param internal whether this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported
+ * to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be
+ * thread safe so that they can be reported correctly.
* @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
-class Accumulable[R, T] (
+class Accumulable[R, T] private[spark] (
@transient initialValue: R,
param: AccumulableParam[R, T],
- val name: Option[String])
+ val name: Option[String],
+ internal: Boolean)
extends Serializable {
+ private[spark] def this(
+ @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = {
+ this(initialValue, param, None, internal)
+ }
+
+ def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) =
+ this(initialValue, param, name, false)
+
def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
this(initialValue, param, None)
val id: Long = Accumulators.newId
- @transient private var value_ = initialValue // Current value on master
+ @volatile @transient private var value_ : R = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
private var deserialized = false
- Accumulators.register(this, true)
+ Accumulators.register(this)
+
+ /**
+ * Whether this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver
+ * via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so that they can be
+ * reported correctly.
+ */
+ private[spark] def isInternal: Boolean = internal
/**
* Add more data to this accumulator / accumulable
@@ -132,7 +152,8 @@ class Accumulable[R, T] (
in.defaultReadObject()
value_ = zero
deserialized = true
- Accumulators.register(this, false)
+ val taskContext = TaskContext.get()
+ taskContext.registerAccumulator(this)
}
override def toString: String = if (value_ == null) "null" else value_.toString
@@ -284,16 +305,7 @@ private[spark] object Accumulators extends Logging {
* It keeps weak references to these objects so that accumulators can be garbage-collected
* once the RDDs and user-code that reference them are cleaned up.
*/
- val originals = Map[Long, WeakReference[Accumulable[_, _]]]()
-
- /**
- * This thread-local map holds per-task copies of accumulators; it is used to collect the set
- * of accumulator updates to send back to the driver when tasks complete. After tasks complete,
- * this map is cleared by `Accumulators.clear()` (see Executor.scala).
- */
- private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
- override protected def initialValue() = Map[Long, Accumulable[_, _]]()
- }
+ val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
private var lastId: Long = 0
@@ -302,19 +314,8 @@ private[spark] object Accumulators extends Logging {
lastId
}
- def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
- if (original) {
- originals(a.id) = new WeakReference[Accumulable[_, _]](a)
- } else {
- localAccums.get()(a.id) = a
- }
- }
-
- // Clear the local (non-original) accumulators for the current thread
- def clear() {
- synchronized {
- localAccums.get.clear()
- }
+ def register(a: Accumulable[_, _]): Unit = synchronized {
+ originals(a.id) = new WeakReference[Accumulable[_, _]](a)
}
def remove(accId: Long) {
@@ -323,15 +324,6 @@ private[spark] object Accumulators extends Logging {
}
}
- // Get the values of the local accumulators for the current thread (by ID)
- def values: Map[Long, Any] = synchronized {
- val ret = Map[Long, Any]()
- for ((id, accum) <- localAccums.get) {
- ret(id) = accum.localValue
- }
- return ret
- }
-
// Add values to the original accumulators with some given IDs
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
@@ -349,7 +341,4 @@ private[spark] object Accumulators extends Logging {
}
}
- def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue)
-
- def stringifyValue(value: Any): String = "%s".format(value)
}
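The change above removes the thread-local `localAccums` map and instead has deserialized accumulator copies register with the running `TaskContext`. Below is a minimal, self-contained sketch of that split, written to illustrate the flow rather than Spark's implementation; `Acc`, `DriverRegistry`, and `TaskScope` are illustrative stand-ins.

```
import scala.collection.mutable
import scala.ref.WeakReference

// Illustrative model of the registration split: the driver keeps weak references to the
// original accumulators, while each task registers its deserialized copies and reports
// their local values when it finishes (e.g. via heartbeats).
class Acc(val id: Long) {
  var value: Long = 0L
}

object DriverRegistry {
  private val originals = mutable.Map[Long, WeakReference[Acc]]()

  def register(a: Acc): Unit = synchronized { originals(a.id) = new WeakReference(a) }

  def add(updates: Map[Long, Long]): Unit = synchronized {
    for ((id, v) <- updates; ref <- originals.get(id); acc <- ref.get) {
      acc.value += v
    }
  }
}

class TaskScope {
  private val accumulators = mutable.HashMap[Long, Acc]()

  def registerAccumulator(a: Acc): Unit = synchronized { accumulators(a.id) = a }

  def collectAccumulators(): Map[Long, Long] =
    synchronized { accumulators.map { case (id, acc) => (id, acc.value) }.toMap }
}

object AccumulatorSketch {
  def main(args: Array[String]): Unit = {
    val original = new Acc(1L)
    DriverRegistry.register(original)

    // On an "executor", the deserialized copy registers with its task's scope instead of a
    // JVM-wide thread-local map, so values can be collected per task.
    val taskScope = new TaskScope
    val copy = new Acc(1L)
    taskScope.registerAccumulator(copy)
    copy.value += 42L

    DriverRegistry.add(taskScope.collectAccumulators())
    println(original.value) // 42
  }
}
```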
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
index 443830f8d03b6..842bfdbadc948 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -24,11 +24,23 @@ package org.apache.spark
private[spark] trait ExecutorAllocationClient {
/**
- * Express a preference to the cluster manager for a given total number of executors.
- * This can result in canceling pending requests or filing additional requests.
+ * Update the cluster manager on our scheduling needs. Three bits of information are included
+ * to help it make decisions.
+ * @param numExecutors The total number of executors we'd like to have. The cluster manager
+ * shouldn't kill any running executor to reach this number, but,
+ * if all existing executors were to die, this is the number of executors
+ * we'd want to be allocated.
+ * @param localityAwareTasks The number of tasks in all active stages that have locality
+ * preferences. This includes running, pending, and completed tasks.
+ * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages
+ * that would like to run on that host.
+ * This includes running, pending, and completed tasks.
* @return whether the request is acknowledged by the cluster manager.
*/
- private[spark] def requestTotalExecutors(numExecutors: Int): Boolean
+ private[spark] def requestTotalExecutors(
+ numExecutors: Int,
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: Map[String, Int]): Boolean
/**
* Request an additional number of executors from the cluster manager.
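The trait above is `private[spark]`, so the sketch below uses a stand-in trait with the same shape just to show what the new locality hints look like from a cluster manager's point of view; all names here are hypothetical.

```
// Stand-in trait mirroring the expanded contract (not Spark's private[spark] API): the
// scheduler now sends locality hints alongside the desired executor total, so a cluster
// manager can prefer hosts that have pending local tasks.
trait AllocationClientSketch {
  def requestTotalExecutors(
      numExecutors: Int,
      localityAwareTasks: Int,
      hostToLocalTaskCount: Map[String, Int]): Boolean
}

object AllocationClientDemo {
  def main(args: Array[String]): Unit = {
    val client = new AllocationClientSketch {
      override def requestTotalExecutors(
          numExecutors: Int,
          localityAwareTasks: Int,
          hostToLocalTaskCount: Map[String, Int]): Boolean = {
        println(s"want $numExecutors executors; $localityAwareTasks locality-aware tasks, " +
          s"preferred hosts: ${hostToLocalTaskCount.mkString(", ")}")
        true
      }
    }

    // 8 locality-aware tasks in total: 5 would prefer host1 and 3 would prefer host2.
    client.requestTotalExecutors(numExecutors = 4, localityAwareTasks = 8,
      hostToLocalTaskCount = Map("host1" -> 5, "host2" -> 3))
  }
}
```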
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 49329423dca76..1877aaf2cac55 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import java.util.concurrent.TimeUnit
import scala.collection.mutable
+import scala.util.control.ControlThrowable
import com.codahale.metrics.{Gauge, MetricRegistry}
@@ -102,7 +103,7 @@ private[spark] class ExecutorAllocationManager(
"spark.dynamicAllocation.executorIdleTimeout", "60s")
private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds(
- "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s")
+ "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${Integer.MAX_VALUE}s")
// During testing, the methods to actually kill and add executors are mocked out
private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
@@ -160,6 +161,12 @@ private[spark] class ExecutorAllocationManager(
// (2) an executor idle timeout has elapsed.
@volatile private var initializing: Boolean = true
+ // Number of locality aware tasks, used for executor placement.
+ private var localityAwareTasks = 0
+
+ // Map of host to the number of tasks that would like to run on it, used for executor placement.
+ private var hostToLocalTaskCount: Map[String, Int] = Map.empty
+
/**
* Verify that the settings specified through the config are valid.
* If not, throw an appropriate exception.
@@ -211,7 +218,16 @@ private[spark] class ExecutorAllocationManager(
listenerBus.addListener(listener)
val scheduleTask = new Runnable() {
- override def run(): Unit = Utils.logUncaughtExceptions(schedule())
+ override def run(): Unit = {
+ try {
+ schedule()
+ } catch {
+ case ct: ControlThrowable =>
+ throw ct
+ case t: Throwable =>
+ logWarning(s"Uncaught exception in thread ${Thread.currentThread().getName}", t)
+ }
+ }
}
executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS)
}
@@ -285,7 +301,7 @@ private[spark] class ExecutorAllocationManager(
// If the new target has not changed, avoid sending a message to the cluster manager
if (numExecutorsTarget < oldNumExecutorsTarget) {
- client.requestTotalExecutors(numExecutorsTarget)
+ client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount)
logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " +
s"$oldNumExecutorsTarget) because not all requested executors are actually needed")
}
@@ -339,7 +355,8 @@ private[spark] class ExecutorAllocationManager(
return 0
}
- val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget)
+ val addRequestAcknowledged = testing ||
+ client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount)
if (addRequestAcknowledged) {
val executorsString = "executor" + { if (delta > 1) "s" else "" }
logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" +
@@ -509,6 +526,12 @@ private[spark] class ExecutorAllocationManager(
// Number of tasks currently running on the cluster. Should be 0 when no stages are active.
private var numRunningTasks: Int = _
+ // Map from each stage id to a tuple of (the number of tasks with locality preferences, a map
+ // from node to the number of tasks that would like to be scheduled on that node). This
+ // maintains the executor placement hints for each stage, which the resource framework uses
+ // to better place executors.
+ private val stageIdToExecutorPlacementHints = new mutable.HashMap[Int, (Int, Map[String, Int])]
+
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
initializing = false
val stageId = stageSubmitted.stageInfo.stageId
@@ -516,6 +539,24 @@ private[spark] class ExecutorAllocationManager(
allocationManager.synchronized {
stageIdToNumTasks(stageId) = numTasks
allocationManager.onSchedulerBacklogged()
+
+ // Compute the number of tasks requested by the stage on each host
+ var numTasksPending = 0
+ val hostToLocalTaskCountPerStage = new mutable.HashMap[String, Int]()
+ stageSubmitted.stageInfo.taskLocalityPreferences.foreach { locality =>
+ if (!locality.isEmpty) {
+ numTasksPending += 1
+ locality.foreach { location =>
+ val count = hostToLocalTaskCountPerStage.getOrElse(location.host, 0) + 1
+ hostToLocalTaskCountPerStage(location.host) = count
+ }
+ }
+ }
+ stageIdToExecutorPlacementHints.put(stageId,
+ (numTasksPending, hostToLocalTaskCountPerStage.toMap))
+
+ // Update the executor placement hints
+ updateExecutorPlacementHints()
}
}
@@ -524,6 +565,10 @@ private[spark] class ExecutorAllocationManager(
allocationManager.synchronized {
stageIdToNumTasks -= stageId
stageIdToTaskIndices -= stageId
+ stageIdToExecutorPlacementHints -= stageId
+
+ // Update the executor placement hints
+ updateExecutorPlacementHints()
// If this is the last stage with pending tasks, mark the scheduler queue as empty
// This is needed in case the stage is aborted for any reason
@@ -627,6 +672,29 @@ private[spark] class ExecutorAllocationManager(
def isExecutorIdle(executorId: String): Boolean = {
!executorIdToTaskIds.contains(executorId)
}
+
+ /**
+ * Update the Executor placement hints (the number of tasks with locality preferences,
+ * a map where each pair is a node and the number of tasks that would like to be scheduled
+ * on that node).
+ *
+ * These hints are updated when stages arrive and complete, so are not up-to-date at task
+ * granularity within stages.
+ */
+ def updateExecutorPlacementHints(): Unit = {
+ var localityAwareTasks = 0
+ val localityToCount = new mutable.HashMap[String, Int]()
+ stageIdToExecutorPlacementHints.values.foreach { case (numTasksPending, localities) =>
+ localityAwareTasks += numTasksPending
+ localities.foreach { case (hostname, count) =>
+ val updatedCount = localityToCount.getOrElse(hostname, 0) + count
+ localityToCount(hostname) = updatedCount
+ }
+ }
+
+ allocationManager.localityAwareTasks = localityAwareTasks
+ allocationManager.hostToLocalTaskCount = localityToCount.toMap
+ }
}
/**
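A self-contained sketch of the aggregation that `updateExecutorPlacementHints()` performs: the per-stage (pending locality-aware tasks, host-to-task-count) entries are folded into cluster-wide totals whenever a stage is submitted or completed. The object name and inputs are illustrative.

```
import scala.collection.mutable

// Fold per-stage placement hints into the totals that get passed to requestTotalExecutors.
object PlacementHintsSketch {
  def aggregate(stageHints: Map[Int, (Int, Map[String, Int])]): (Int, Map[String, Int]) = {
    var localityAwareTasks = 0
    val localityToCount = new mutable.HashMap[String, Int]()
    stageHints.values.foreach { case (numTasksPending, localities) =>
      localityAwareTasks += numTasksPending
      localities.foreach { case (hostname, count) =>
        localityToCount(hostname) = localityToCount.getOrElse(hostname, 0) + count
      }
    }
    (localityAwareTasks, localityToCount.toMap)
  }

  def main(args: Array[String]): Unit = {
    val hints = Map(
      0 -> (3, Map("host1" -> 2, "host2" -> 1)),
      1 -> (2, Map("host2" -> 2)))
    // Prints (5, Map(host1 -> 2, host2 -> 3)): the cluster-wide totals.
    println(aggregate(hints))
  }
}
```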
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 221b1dab43278..43dd4a170731d 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -181,7 +181,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
// Asynchronously kill the executor to avoid blocking the current thread
killExecutorThread.submit(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
- sc.killExecutor(executorId)
+ // Note: we want to get an executor back after expiring this one,
+ // so do not simply call `sc.killExecutor` here (SPARK-8119)
+ sc.killAndReplaceExecutor(executorId)
}
})
}
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index 7fcb7830e7b0b..f0598816d6c07 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -121,6 +121,7 @@ trait Logging {
if (usingLog4j12) {
val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
if (!log4j12Initialized) {
+ // scalastyle:off println
if (Utils.isInInterpreter) {
val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties"
Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match {
@@ -141,6 +142,7 @@ trait Logging {
System.err.println(s"Spark was unable to load $defaultLogProps")
}
}
+ // scalastyle:on println
}
}
Logging.initialized = true
@@ -157,7 +159,7 @@ private object Logging {
try {
// We use reflection here to handle the case where users remove the
// slf4j-to-jul bridge in order to route their logs to JUL.
- val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler")
+ val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler")
bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null)
val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean]
if (!installed) {
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 862ffe868f58f..92218832d256f 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,14 +21,14 @@ import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.{HashMap, HashSet, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util._
private[spark] sealed trait MapOutputTrackerMessage
@@ -124,10 +124,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
/**
- * Called from executors to get the server URIs and output sizes of the map outputs of
- * a given shuffle.
+ * Called from executors to get the server URIs and output sizes for each shuffle block that
+ * needs to be read from a given reduce task.
+ *
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block id, shuffle block size) tuples
+ * describing the shuffle blocks that are stored at that block manager.
*/
- def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
+ def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
+ : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId")
+ val startTime = System.currentTimeMillis
+
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
@@ -167,6 +175,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
}
}
+ logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " +
+ s"${System.currentTimeMillis - startTime} ms")
+
if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
@@ -421,23 +432,38 @@ private[spark] object MapOutputTracker extends Logging {
}
}
- // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
- // any of the statuses is null (indicating a missing location due to a failed mapper),
- // throw a FetchFailedException.
+ /**
+ * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block
+ * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that
+ * block manager.
+ *
+ * If any of the statuses is null (indicating a missing location due to a failed mapper),
+ * throws a FetchFailedException.
+ *
+ * @param shuffleId Identifier for the shuffle
+ * @param reduceId Identifier for the reduce task
+ * @param statuses List of map statuses, indexed by map ID.
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block id, shuffle block size) tuples
+ * describing the shuffle blocks that are stored at that block manager.
+ */
private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
- statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
+ statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
assert (statuses != null)
- statuses.map {
- status =>
- if (status == null) {
- logError("Missing an output location for shuffle " + shuffleId)
- throw new MetadataFetchFailedException(
- shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
- } else {
- (status.location, status.getSizeForBlock(reduceId))
- }
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
+ for ((status, mapId) <- statuses.zipWithIndex) {
+ if (status == null) {
+ val errorMessage = s"Missing an output location for shuffle $shuffleId"
+ logError(errorMessage)
+ throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage)
+ } else {
+ splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
+ ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId)))
+ }
}
+
+ splitsByAddress.toSeq
}
}
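To make the new return shape of `getMapSizesByExecutorId` concrete, here is a sketch of the grouping `convertMapStatuses` now performs, using simple case classes as stand-ins for `BlockManagerId`, `ShuffleBlockId`, and `MapStatus`.

```
import scala.collection.mutable.{ArrayBuffer, HashMap}

// For one reduce task, map-side statuses are grouped into
// (block manager, Seq((shuffle block id, size))) entries.
case class BlockManagerIdStub(executorId: String, host: String)
case class ShuffleBlockIdStub(shuffleId: Int, mapId: Int, reduceId: Int)
case class MapStatusStub(location: BlockManagerIdStub, sizeForReduce: Long)

object MapOutputGroupingSketch {
  def convert(
      shuffleId: Int,
      reduceId: Int,
      statuses: Array[MapStatusStub])
    : Seq[(BlockManagerIdStub, Seq[(ShuffleBlockIdStub, Long)])] = {
    val splitsByAddress = new HashMap[BlockManagerIdStub, ArrayBuffer[(ShuffleBlockIdStub, Long)]]
    for ((status, mapId) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
        ((ShuffleBlockIdStub(shuffleId, mapId, reduceId), status.sizeForReduce))
    }
    splitsByAddress.toSeq
  }

  def main(args: Array[String]): Unit = {
    val bm1 = BlockManagerIdStub("exec-1", "host1")
    val bm2 = BlockManagerIdStub("exec-2", "host2")
    val statuses = Array(MapStatusStub(bm1, 100L), MapStatusStub(bm2, 50L), MapStatusStub(bm1, 10L))
    // Two entries: bm1 with blocks from maps 0 and 2, bm2 with the block from map 1.
    convert(shuffleId = 0, reduceId = 3, statuses = statuses).foreach(println)
  }
}
```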
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 82889bcd30988..4b9d59975bdc2 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -56,7 +56,7 @@ object Partitioner {
*/
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
- for (r <- bySize if r.partitioner.isDefined) {
+ for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) {
return r.partitioner.get
}
if (rdd.context.conf.contains("spark.default.parallelism")) {
@@ -76,6 +76,8 @@ object Partitioner {
* produce an unexpected or incorrect result.
*/
class HashPartitioner(partitions: Int) extends Partitioner {
+ require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
+
def numPartitions: Int = partitions
def getPartition(key: Any): Int = key match {
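The new `require` and the zero-partition check in `defaultPartitioner` both guard against partitioners that cannot place any key. A minimal sketch of why, using a stand-in class rather than Spark's `HashPartitioner`:

```
// A hash partitioner computes key.hashCode % numPartitions, so a partitioner with zero
// partitions cannot place any key, and a negative partition count makes no sense at all.
class HashPartitionerSketch(partitions: Int) {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = {
    val mod = key.hashCode % numPartitions    // throws ArithmeticException if numPartitions == 0
    if (mod < 0) mod + numPartitions else mod // keep the result non-negative
  }
}

object PartitionerSketchDemo {
  def main(args: Array[String]): Unit = {
    val p = new HashPartitionerSketch(4)
    println(p.getPartition("some-key")) // a value in [0, 4)
    // new HashPartitionerSketch(-1)    // would fail the require() check
  }
}
```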
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 6cf36fbbd6254..4161792976c7b 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -18,11 +18,12 @@
package org.apache.spark
import java.util.concurrent.ConcurrentHashMap
-import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import scala.collection.mutable.LinkedHashSet
+import org.apache.avro.{SchemaNormalization, Schema}
+
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.Utils
@@ -161,6 +162,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
this
}
+ private final val avroNamespace = "avro.schema."
+
+ /**
+ * Use Kryo serialization and register the given set of Avro schemas so that the generic
+ * record serializer can decrease network IO
+ */
+ def registerAvroSchemas(schemas: Schema*): SparkConf = {
+ for (schema <- schemas) {
+ set(avroNamespace + SchemaNormalization.parsingFingerprint64(schema), schema.toString)
+ }
+ this
+ }
+
+ /** Gets all the Avro schemas in the configuration used by the generic Avro record serializer */
+ def getAvroSchema: Map[Long, String] = {
+ getAll.filter { case (k, v) => k.startsWith(avroNamespace) }
+ .map { case (k, v) => (k.substring(avroNamespace.length).toLong, v) }
+ .toMap
+ }
+
/** Remove a parameter from the configuration */
def remove(key: String): SparkConf = {
settings.remove(key)
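A usage sketch for the two methods added to `SparkConf` above, assuming spark-core and Avro are on the classpath; the record schema and app name are illustrative.

```
import org.apache.avro.Schema
import org.apache.spark.SparkConf

// Register an Avro schema on the SparkConf and read back the fingerprint -> schema map
// that the generic Avro record serializer consumes.
object AvroSchemaRegistrationSketch {
  def main(args: Array[String]): Unit = {
    val schemaJson =
      """{"type":"record","name":"User","fields":[
        |  {"name":"name","type":"string"},
        |  {"name":"age","type":"int"}
        |]}""".stripMargin
    val schema = new Schema.Parser().parse(schemaJson)

    val conf = new SparkConf(false)
      .setAppName("avro-schema-demo")
      .registerAvroSchemas(schema)

    // Keys are 64-bit parsing fingerprints, values are the schema JSON strings.
    conf.getAvroSchema.foreach { case (fingerprint, json) =>
      println(s"$fingerprint -> $json")
    }
  }
}
```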
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index d2547eeff2b4e..ac6ac6c216767 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -471,7 +471,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
.orElse(Option(System.getenv("SPARK_MEM"))
.map(warnSparkMem))
.map(Utils.memoryStringToMb)
- .getOrElse(512)
+ .getOrElse(1024)
// Convert java options to env vars as a work around
// since we can't set env vars directly in sbt.
@@ -532,7 +532,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_executorAllocationManager =
if (dynamicAllocationEnabled) {
assert(supportDynamicAllocation,
- "Dynamic allocation of executors is currently only supported in YARN mode")
+ "Dynamic allocation of executors is currently only supported in YARN and Mesos mode")
Some(new ExecutorAllocationManager(this, listenerBus, _conf))
} else {
None
@@ -853,7 +853,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
minPartitions).setName(path)
}
-
/**
* :: Experimental ::
*
@@ -1364,10 +1363,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* Return whether dynamically adjusting the amount of resources allocated to
- * this application is supported. This is currently only available for YARN.
+ * this application is supported. This is currently only available for YARN
+ * and Mesos coarse-grained mode.
*/
- private[spark] def supportDynamicAllocation =
- master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false)
+ private[spark] def supportDynamicAllocation: Boolean = {
+ (master.contains("yarn")
+ || master.contains("mesos")
+ || _conf.getBoolean("spark.dynamicAllocation.testing", false))
+ }
/**
* :: DeveloperApi ::
@@ -1379,16 +1382,29 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/**
- * Express a preference to the cluster manager for a given total number of executors.
- * This can result in canceling pending requests or filing additional requests.
- * This is currently only supported in YARN mode. Return whether the request is received.
- */
- private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = {
+ * Update the cluster manager on our scheduling needs. Three bits of information are included
+ * to help it make decisions.
+ * @param numExecutors The total number of executors we'd like to have. The cluster manager
+ * shouldn't kill any running executor to reach this number, but,
+ * if all existing executors were to die, this is the number of executors
+ * we'd want to be allocated.
+ * @param localityAwareTasks The number of tasks in all active stages that have locality
+ * preferences. This includes running, pending, and completed tasks.
+ * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages
+ * that would like to run on that host.
+ * This includes running, pending, and completed tasks.
+ * @return whether the request is acknowledged by the cluster manager.
+ */
+ private[spark] override def requestTotalExecutors(
+ numExecutors: Int,
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: scala.collection.immutable.Map[String, Int]
+ ): Boolean = {
assert(supportDynamicAllocation,
- "Requesting executors is currently only supported in YARN mode")
+ "Requesting executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
- b.requestTotalExecutors(numExecutors)
+ b.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount)
case _ =>
logWarning("Requesting executors is only supported in coarse-grained mode")
false
@@ -1403,7 +1419,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@DeveloperApi
override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
assert(supportDynamicAllocation,
- "Requesting executors is currently only supported in YARN mode")
+ "Requesting executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.requestExecutors(numAdditionalExecutors)
@@ -1416,12 +1432,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* :: DeveloperApi ::
* Request that the cluster manager kill the specified executors.
+ *
+ * Note: This is an indication to the cluster manager that the application wishes to adjust
+ * its resource usage downwards. If the application wishes to replace the executors it kills
+ * through this method with new ones, it should follow up explicitly with a call to
+ * {{SparkContext#requestExecutors}}.
+ *
* This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
override def killExecutors(executorIds: Seq[String]): Boolean = {
assert(supportDynamicAllocation,
- "Killing executors is currently only supported in YARN mode")
+ "Killing executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.killExecutors(executorIds)
@@ -1433,12 +1455,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* :: DeveloperApi ::
- * Request that cluster manager the kill the specified executor.
- * This is currently only supported in Yarn mode. Return whether the request is received.
+ * Request that the cluster manager kill the specified executor.
+ *
+ * Note: This is an indication to the cluster manager that the application wishes to adjust
+ * its resource usage downwards. If the application wishes to replace the executor it kills
+ * through this method with a new one, it should follow up explicitly with a call to
+ * {{SparkContext#requestExecutors}}.
+ *
+ * This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId)
+ /**
+ * Request that the cluster manager kill the specified executor without adjusting the
+ * application resource requirements.
+ *
+ * The effect is that a new executor will be launched in place of the one killed by
+ * this request. This assumes the cluster manager will automatically and eventually
+ * fulfill all missing application resource requests.
+ *
+ * Note: The replacement is by no means guaranteed; another application on the same cluster
+ * can steal the window of opportunity and acquire this application's resources in the
+ * meantime.
+ *
+ * This is currently only supported in YARN mode. Return whether the request is received.
+ */
+ private[spark] def killAndReplaceExecutor(executorId: String): Boolean = {
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.killExecutors(Seq(executorId), replace = true)
+ case _ =>
+ logWarning("Killing executors is only supported in coarse-grained mode")
+ false
+ }
+ }
+
/** The version of Spark on which this application is running. */
def version: String = SPARK_VERSION
@@ -1719,16 +1771,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* Run a function on a given set of partitions in an RDD and pass the results to the given
- * handler function. This is the main entry point for all actions in Spark. The allowLocal
- * flag specifies whether the scheduler can run the computation on the driver rather than
- * shipping it out to the cluster, for short actions like first().
+ * handler function. This is the main entry point for all actions in Spark.
*/
def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
- allowLocal: Boolean,
- resultHandler: (Int, U) => Unit) {
+ resultHandler: (Int, U) => Unit): Unit = {
if (stopped.get()) {
throw new IllegalStateException("SparkContext has been shutdown")
}
@@ -1738,54 +1787,104 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
if (conf.getBoolean("spark.logLineage", false)) {
logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
}
- dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
- resultHandler, localProperties.get)
+ dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
progressBar.foreach(_.finishAll())
rdd.doCheckpoint()
}
/**
- * Run a function on a given set of partitions in an RDD and return the results as an array. The
- * allowLocal flag specifies whether the scheduler can run the computation on the driver rather
- * than shipping it out to the cluster, for short actions like first().
+ * Run a function on a given set of partitions in an RDD and return the results as an array.
+ */
+ def runJob[T, U: ClassTag](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int]): Array[U] = {
+ val results = new Array[U](partitions.size)
+ runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res)
+ results
+ }
+
+ /**
+ * Run a job on a given set of partitions of an RDD, but take a function of type
+ * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
+ */
+ def runJob[T, U: ClassTag](
+ rdd: RDD[T],
+ func: Iterator[T] => U,
+ partitions: Seq[Int]): Array[U] = {
+ val cleanedFunc = clean(func)
+ runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions)
+ }
+
+
+ /**
+ * Run a function on a given set of partitions in an RDD and pass the results to the given
+ * handler function. This is the main entry point for all actions in Spark.
+ *
+ * The allowLocal flag is deprecated as of Spark 1.5.0+.
+ */
+ @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0")
+ def runJob[T, U: ClassTag](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit): Unit = {
+ if (allowLocal) {
+ logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+")
+ }
+ runJob(rdd, func, partitions, resultHandler)
+ }
+
+ /**
+ * Run a function on a given set of partitions in an RDD and return the results as an array.
+ *
+ * The allowLocal flag is deprecated as of Spark 1.5.0+.
*/
+ @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0")
def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
- val results = new Array[U](partitions.size)
- runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
- results
+ if (allowLocal) {
+ logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+")
+ }
+ runJob(rdd, func, partitions)
}
/**
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
+ *
+ * The allowLocal argument is deprecated as of Spark 1.5.0+.
*/
+ @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0")
def runJob[T, U: ClassTag](
rdd: RDD[T],
func: Iterator[T] => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
- val cleanedFunc = clean(func)
- runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal)
+ if (allowLocal) {
+ logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+")
+ }
+ runJob(rdd, func, partitions)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
- runJob(rdd, func, 0 until rdd.partitions.size, false)
+ runJob(rdd, func, 0 until rdd.partitions.length)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
- runJob(rdd, func, 0 until rdd.partitions.size, false)
+ runJob(rdd, func, 0 until rdd.partitions.length)
}
/**
@@ -1796,7 +1895,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
{
- runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler)
+ runJob[T, U](rdd, processPartition, 0 until rdd.partitions.length, resultHandler)
}
/**
@@ -1808,7 +1907,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
resultHandler: (Int, U) => Unit)
{
val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
- runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler)
+ runJob[T, U](rdd, processFunc, 0 until rdd.partitions.length, resultHandler)
}
/**
@@ -1853,7 +1952,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(context: TaskContext, iter: Iterator[T]) => cleanF(iter),
partitions,
callSite,
- allowLocal = false,
resultHandler,
localProperties.get)
new SimpleFutureAction(waiter, resultFunc)
@@ -1965,7 +2063,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
for (className <- listenerClassNames) {
// Use reflection to find the right constructor
val constructors = {
- val listenerClass = Class.forName(className)
+ val listenerClass = Utils.classForName(className)
listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]]
}
val constructorTakingSparkConf = constructors.find { c =>
@@ -2500,7 +2598,7 @@ object SparkContext extends Logging {
"\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.")
}
val scheduler = try {
- val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
+ val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
} catch {
@@ -2512,7 +2610,7 @@ object SparkContext extends Logging {
}
val backend = try {
val clazz =
- Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend")
+ Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend")
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
@@ -2525,8 +2623,7 @@ object SparkContext extends Logging {
case "yarn-client" =>
val scheduler = try {
- val clazz =
- Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler")
+ val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
@@ -2538,7 +2635,7 @@ object SparkContext extends Logging {
val backend = try {
val clazz =
- Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
+ Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
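A usage sketch for the `runJob` overloads above that drop the `allowLocal` flag, assuming a local master for illustration. The old `allowLocal` variants remain but are deprecated and simply delegate after logging a warning.

```
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object RunJobSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("runJob-demo"))
    try {
      val rdd = sc.parallelize(1 to 100, 4)

      // Sum only the first two partitions; no allowLocal argument is needed any more.
      val partialSums: Array[Int] = sc.runJob(rdd, (it: Iterator[Int]) => it.sum, Seq(0, 1))
      println(partialSums.mkString(", "))

      // The (TaskContext, Iterator[T]) => U form gives access to the running task's context.
      val partitionIds: Array[Int] =
        sc.runJob(rdd, (ctx: TaskContext, it: Iterator[Int]) => ctx.partitionId(), Seq(2, 3))
      println(partitionIds.mkString(", "))
    } finally {
      sc.stop()
    }
  }
}
```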
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index d18fc599e9890..adfece4d6e7c0 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -261,7 +261,7 @@ object SparkEnv extends Logging {
// Create an instance of the class with the given name, possibly initializing it with our conf
def instantiateClass[T](className: String): T = {
- val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader)
+ val cls = Utils.classForName(className)
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index d09e17dea0911..b48836d5c8897 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,6 +21,7 @@ import java.io.Serializable
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.metrics.source.Source
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.TaskCompletionListener
@@ -32,7 +33,20 @@ object TaskContext {
*/
def get(): TaskContext = taskContext.get
- private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]
+ /**
+ * Returns the partition id of the currently active TaskContext. It will return 0
+ * if there is no active TaskContext for cases like local execution.
+ */
+ def getPartitionId(): Int = {
+ val tc = taskContext.get()
+ if (tc eq null) {
+ 0
+ } else {
+ tc.partitionId()
+ }
+ }
+
+ private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]
// Note: protected[spark] instead of private[spark] to prevent the following two from
// showing up in JavaDoc.
@@ -135,8 +149,34 @@ abstract class TaskContext extends Serializable {
@DeveloperApi
def taskMetrics(): TaskMetrics
+ /**
+ * ::DeveloperApi::
+ * Returns all metrics sources with the given name that are associated with the instance
+ * running the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]].
+ */
+ @DeveloperApi
+ def getMetricsSources(sourceName: String): Seq[Source]
+
/**
* Returns the manager for this task's managed memory.
*/
private[spark] def taskMemoryManager(): TaskMemoryManager
+
+ /**
+ * Register an accumulator that belongs to this task. Accumulators must call this method when
+ * deserializing in executors.
+ */
+ private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit
+
+ /**
+ * Return the local values of internal accumulators that belong to this task. The key of the Map
+ * is the accumulator id and the value of the Map is the latest accumulator local value.
+ */
+ private[spark] def collectInternalAccumulators(): Map[Long, Any]
+
+ /**
+ * Return the local values of accumulators that belong to this task. The key of the Map is the
+ * accumulator id and the value of the Map is the latest accumulator local value.
+ */
+ private[spark] def collectAccumulators(): Map[Long, Any]
}
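A usage sketch for the `TaskContext.getPartitionId()` helper added above, again assuming a local master. Unlike `TaskContext.get()`, it is safe to call where no task is active and returns 0 in that case.

```
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object TaskContextSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("taskcontext-demo"))
    try {
      // Outside of any task there is no active TaskContext, so this prints 0.
      println(TaskContext.getPartitionId())

      val tagged = sc.parallelize(1 to 8, 4).mapPartitions { iter =>
        val pid = TaskContext.getPartitionId() // id of the partition this task is computing
        iter.map(x => (pid, x))
      }
      tagged.collect().foreach(println)
    } finally {
      sc.stop()
    }
  }
}
```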
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index b4d572cb52313..9ee168ae016f8 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,18 +17,21 @@
package org.apache.spark
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.metrics.source.Source
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
-import scala.collection.mutable.ArrayBuffer
-
private[spark] class TaskContextImpl(
val stageId: Int,
val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
+ @transient private val metricsSystem: MetricsSystem,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
@@ -94,5 +97,21 @@ private[spark] class TaskContextImpl(
override def isRunningLocally(): Boolean = runningLocally
override def isInterrupted(): Boolean = interrupted
-}
+ override def getMetricsSources(sourceName: String): Seq[Source] =
+ metricsSystem.getSourcesByName(sourceName)
+
+ @transient private val accumulators = new HashMap[Long, Accumulable[_, _]]
+
+ private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized {
+ accumulators(a.id) = a
+ }
+
+ private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized {
+ accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
+ }
+
+ private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
+ accumulators.mapValues(_.localValue).toMap
+ }
+}
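A small sketch of the per-task collection logic above: internal accumulators are simply the subset of registered accumulators flagged as internal. `AccStub` is a stand-in, not Spark's `Accumulable`.

```
import scala.collection.mutable.HashMap

case class AccStub(id: Long, isInternal: Boolean, localValue: Any)

object TaskAccumulatorSketch {
  def main(args: Array[String]): Unit = {
    val accumulators = new HashMap[Long, AccStub]
    Seq(AccStub(1L, isInternal = true, localValue = 128L),
        AccStub(2L, isInternal = false, localValue = "user-metric"))
      .foreach(a => accumulators(a.id) = a)

    // Mirrors collectInternalAccumulators() vs collectAccumulators() in TaskContextImpl.
    val internalOnly = accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
    val all = accumulators.mapValues(_.localValue).toMap
    println(internalOnly) // Map(1 -> 128)
    println(all)          // Map(1 -> 128, 2 -> user-metric)
  }
}
```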
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index c95615a5a9307..829fae1d1d9bf 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -364,7 +364,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
// This is useful for implementing `take` from other language frontends
// like Python where the data is serialized.
import scala.collection.JavaConversions._
- val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds)
res.map(x => new java.util.ArrayList(x.toSeq)).toArray
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index dc9f62f39e6d5..55e563ee968be 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -207,6 +207,7 @@ private[spark] class PythonRDD(
override def run(): Unit = Utils.logUncaughtExceptions {
try {
+ TaskContext.setTaskContext(context)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
@@ -263,11 +264,6 @@ private[spark] class PythonRDD(
if (!worker.isClosed) {
Utils.tryLog(worker.shutdownOutput())
}
- } finally {
- // Release memory used by this thread for shuffles
- env.shuffleMemoryManager.releaseMemoryForThisThread()
- // Release memory used by this thread for unrolling blocks
- env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
}
}
}
@@ -358,12 +354,11 @@ private[spark] object PythonRDD extends Logging {
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
- partitions: JArrayList[Int],
- allowLocal: Boolean): Int = {
+ partitions: JArrayList[Int]): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
- sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
+ sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
serveIterator(flattenedPartition.iterator,
s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 1a5f2bca26c2b..b7e72d4d0ed0b 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -95,7 +95,9 @@ private[spark] class RBackend {
private[spark] object RBackend extends Logging {
def main(args: Array[String]): Unit = {
if (args.length < 1) {
+ // scalastyle:off println
System.err.println("Usage: RBackend ")
+ // scalastyle:on println
System.exit(-1)
}
val sparkRBackend = new RBackend()
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 4b8f7fe9242e0..14dac4ed28ce3 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -20,12 +20,14 @@ package org.apache.spark.api.r
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import scala.collection.mutable.HashMap
+import scala.language.existentials
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import org.apache.spark.Logging
import org.apache.spark.api.r.SerDe._
+import org.apache.spark.util.Utils
/**
* Handler for RBackend
@@ -67,8 +69,11 @@ private[r] class RBackendHandler(server: RBackend)
case e: Exception =>
logError(s"Removing $objId failed", e)
writeInt(dos, -1)
+ writeString(dos, s"Removing $objId failed: ${e.getMessage}")
}
- case _ => dos.writeInt(-1)
+ case _ =>
+ dos.writeInt(-1)
+ writeString(dos, s"Error: unknown method $methodName")
}
} else {
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
@@ -88,21 +93,6 @@ private[r] class RBackendHandler(server: RBackend)
ctx.close()
}
- // Looks up a class given a class name. This function first checks the
- // current class loader and if a class is not found, it looks up the class
- // in the context class loader. Address [SPARK-5185]
- def getStaticClass(objId: String): Class[_] = {
- try {
- val clsCurrent = Class.forName(objId)
- clsCurrent
- } catch {
- // Use contextLoader if we can't find the JAR in the system class loader
- case e: ClassNotFoundException =>
- val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader)
- clsContext
- }
- }
-
def handleMethodCall(
isStatic: Boolean,
objId: String,
@@ -113,7 +103,7 @@ private[r] class RBackendHandler(server: RBackend)
var obj: Object = null
try {
val cls = if (isStatic) {
- getStaticClass(objId)
+ Utils.classForName(objId)
} else {
JVMObjectTracker.get(objId) match {
case None => throw new IllegalArgumentException("Object not found " + objId)
@@ -159,8 +149,11 @@ private[r] class RBackendHandler(server: RBackend)
}
} catch {
case e: Exception =>
- logError(s"$methodName on $objId failed", e)
+ logError(s"$methodName on $objId failed")
writeInt(dos, -1)
+ // Write the error message of the exception's cause. This will be returned
+ // to the user in the R process.
+ writeString(dos, Utils.exceptionString(e.getCause))
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 524676544d6f5..1cf2824f862ee 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -39,7 +39,6 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
deserializer: String,
serializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Broadcast[Object]])
extends RDD[U](parent) with Logging {
protected var dataStream: DataInputStream = _
@@ -60,7 +59,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// The stdout/stderr is shared by multiple tasks, because we use one daemon
// to launch child process as worker.
- val errThread = RRDD.createRWorker(rLibDir, listenPort)
+ val errThread = RRDD.createRWorker(listenPort)
// We use two sockets to separate input and output, then it's easy to manage
// the lifecycle of them to avoid deadlock.
@@ -113,6 +112,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
partition: Int): Unit = {
val env = SparkEnv.get
+ val taskContext = TaskContext.get()
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val stream = new BufferedOutputStream(output, bufferSize)
@@ -120,6 +120,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
override def run(): Unit = {
try {
SparkEnv.set(env)
+ TaskContext.setTaskContext(taskContext)
val dataOut = new DataOutputStream(stream)
dataOut.writeInt(partition)
@@ -161,7 +162,9 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
dataOut.write(elem.asInstanceOf[Array[Byte]])
} else if (deserializer == SerializationFormats.STRING) {
// write string(for StringRRDD)
+ // scalastyle:off println
printOut.println(elem)
+ // scalastyle:on println
}
}
@@ -233,11 +236,10 @@ private class PairwiseRRDD[T: ClassTag](
hashFunc: Array[Byte],
deserializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Object])
extends BaseRRDD[T, (Int, Array[Byte])](
parent, numPartitions, hashFunc, deserializer,
- SerializationFormats.BYTE, packageNames, rLibDir,
+ SerializationFormats.BYTE, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
override protected def readData(length: Int): (Int, Array[Byte]) = {
@@ -264,10 +266,9 @@ private class RRDD[T: ClassTag](
deserializer: String,
serializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Object])
extends BaseRRDD[T, Array[Byte]](
- parent, -1, func, deserializer, serializer, packageNames, rLibDir,
+ parent, -1, func, deserializer, serializer, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
override protected def readData(length: Int): Array[Byte] = {
@@ -291,10 +292,9 @@ private class StringRRDD[T: ClassTag](
func: Array[Byte],
deserializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Object])
extends BaseRRDD[T, String](
- parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir,
+ parent, -1, func, deserializer, SerializationFormats.STRING, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
override protected def readData(length: Int): String = {
@@ -390,9 +390,10 @@ private[r] object RRDD {
thread
}
- private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = {
+ private def createRProcess(port: Int, script: String): BufferedStreamThread = {
val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript")
val rOptions = "--vanilla"
+ val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
val rExecScript = rLibDir + "/SparkR/worker/" + script
val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript))
// Unset the R_TESTS environment variable for workers.
@@ -411,7 +412,7 @@ private[r] object RRDD {
/**
* ProcessBuilder used to launch worker R processes.
*/
- def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = {
+ def createRWorker(port: Int): BufferedStreamThread = {
val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
if (!Utils.isWindows && useDaemon) {
synchronized {
@@ -419,7 +420,7 @@ private[r] object RRDD {
// we expect one connection
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val daemonPort = serverSocket.getLocalPort
- errThread = createRProcess(rLibDir, daemonPort, "daemon.R")
+ errThread = createRProcess(daemonPort, "daemon.R")
// the socket used to send out the input of task
serverSocket.setSoTimeout(10000)
val sock = serverSocket.accept()
@@ -441,7 +442,7 @@ private[r] object RRDD {
errThread
}
} else {
- createRProcess(rLibDir, port, "worker.R")
+ createRProcess(port, "worker.R")
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
new file mode 100644
index 0000000000000..d53abd3408c55
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.File
+
+import org.apache.spark.{SparkEnv, SparkException}
+
+private[spark] object RUtils {
+ /**
+ * Get the SparkR package path in the local spark distribution.
+ */
+ def localSparkRPackagePath: Option[String] = {
+ val sparkHome = sys.env.get("SPARK_HOME")
+ sparkHome.map(
+ Seq(_, "R", "lib").mkString(File.separator)
+ )
+ }
+
+ /**
+ * Get the SparkR package path in various deployment modes.
+ * This assumes that Spark properties `spark.master` and `spark.submit.deployMode`
+ * and environment variable `SPARK_HOME` are set.
+ */
+ def sparkRPackagePath(isDriver: Boolean): String = {
+ val (master, deployMode) =
+ if (isDriver) {
+ (sys.props("spark.master"), sys.props("spark.submit.deployMode"))
+ } else {
+ val sparkConf = SparkEnv.get.conf
+ (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode"))
+ }
+
+ val isYarnCluster = master.contains("yarn") && deployMode == "cluster"
+ val isYarnClient = master.contains("yarn") && deployMode == "client"
+
+ // In YARN mode, the SparkR package is distributed as an archive exposed through a
+ // symbolic link named "sparkr" in the current working directory. Note that this does
+ // not apply to the driver in client mode because it runs outside of the cluster.
+ if (isYarnCluster || (isYarnClient && !isDriver)) {
+ new File("sparkr").getAbsolutePath
+ } else {
+ // Otherwise, assume the package is local
+ // TODO: support this for Mesos
+ localSparkRPackagePath.getOrElse {
+ throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")
+ }
+ }
+ }
+}
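The package-path resolution above reduces to a small decision table. The sketch below (not part of the patch) mirrors that logic without depending on SparkEnv or system properties; the property semantics and the "sparkr" link name come from the code above, everything else is illustrative.

```scala
import java.io.File

// Standalone mirror of RUtils.sparkRPackagePath, for illustration only.
// It takes master/deployMode explicitly instead of reading SparkEnv or sys.props.
object SparkRPathSketch {
  def sparkRPackagePath(master: String, deployMode: String, isDriver: Boolean): String = {
    val isYarnCluster = master.contains("yarn") && deployMode == "cluster"
    val isYarnClient = master.contains("yarn") && deployMode == "client"
    if (isYarnCluster || (isYarnClient && !isDriver)) {
      // On YARN the zipped package is shipped and exposed under the link name "sparkr"
      new File("sparkr").getAbsolutePath
    } else {
      // Otherwise fall back to the local distribution under $SPARK_HOME/R/lib
      sys.env.get("SPARK_HOME")
        .map(home => Seq(home, "R", "lib").mkString(File.separator))
        .getOrElse(throw new IllegalStateException("SPARK_HOME not set"))
    }
  }

  def main(args: Array[String]): Unit = {
    // Executors in YARN modes read the package from the shipped archive:
    println(sparkRPackagePath("yarn-client", "client", isDriver = false)) // .../sparkr
    // A yarn-client driver (or any local mode) would instead need SPARK_HOME to be set.
  }
}
```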
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 56adc857d4ce0..d5b4260bf4529 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -179,6 +179,7 @@ private[spark] object SerDe {
// Int -> integer
// String -> character
// Boolean -> logical
+ // Float -> double
// Double -> double
// Long -> double
// Array[Byte] -> raw
@@ -215,6 +216,9 @@ private[spark] object SerDe {
case "long" | "java.lang.Long" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Long].toDouble)
+ case "float" | "java.lang.Float" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Float].toDouble)
case "double" | "java.lang.Double" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Double])
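The new `float` branch simply widens the value and reuses the existing double encoding. Below is a minimal sketch of that pattern over a plain DataOutputStream; the one-byte 'd' tag is an assumed stand-in for SerDe's real type tag, not the actual wire format.

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// Illustration only: Float is widened to Double losslessly and written
// through the same path as Double and Long.
object FloatAsDoubleSketch {
  def writeNumeric(dos: DataOutputStream, value: Any): Unit = value match {
    case f: java.lang.Float  => dos.writeByte('d'); dos.writeDouble(f.doubleValue())
    case d: java.lang.Double => dos.writeByte('d'); dos.writeDouble(d.doubleValue())
    case l: java.lang.Long   => dos.writeByte('d'); dos.writeDouble(l.doubleValue())
    case other => throw new IllegalArgumentException(s"unsupported type: $other")
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    writeNumeric(new DataOutputStream(buf), 1.5f)
    println(s"wrote ${buf.size} bytes") // 1 tag byte + 8 payload bytes = 9
  }
}
```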
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index 685313ac009ba..fac6666bb3410 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.util.Utils
private[spark] class BroadcastManager(
val isDriver: Boolean,
@@ -42,7 +43,7 @@ private[spark] class BroadcastManager(
conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
broadcastFactory =
- Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+ Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 71f7e2129116f..f03875a3e8c89 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -118,26 +118,26 @@ private class ClientEndpoint(
def pollAndReportStatus(driverId: String) {
// Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread
// is fine.
- println("... waiting before polling master for driver state")
+ logInfo("... waiting before polling master for driver state")
Thread.sleep(5000)
- println("... polling master for driver state")
+ logInfo("... polling master for driver state")
val statusResponse =
activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId))
statusResponse.found match {
case false =>
- println(s"ERROR: Cluster master did not recognize $driverId")
+ logError(s"ERROR: Cluster master did not recognize $driverId")
System.exit(-1)
case true =>
- println(s"State of $driverId is ${statusResponse.state.get}")
+ logInfo(s"State of $driverId is ${statusResponse.state.get}")
// Worker node, if present
(statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match {
case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) =>
- println(s"Driver running on $hostPort ($id)")
+ logInfo(s"Driver running on $hostPort ($id)")
case _ =>
}
// Exception, if present
statusResponse.exception.map { e =>
- println(s"Exception from cluster was: $e")
+ logError(s"Exception from cluster was: $e")
e.printStackTrace()
System.exit(-1)
}
@@ -148,7 +148,7 @@ private class ClientEndpoint(
override def receive: PartialFunction[Any, Unit] = {
case SubmitDriverResponse(master, success, driverId, message) =>
- println(message)
+ logInfo(message)
if (success) {
activeMasterEndpoint = master
pollAndReportStatus(driverId.get)
@@ -158,7 +158,7 @@ private class ClientEndpoint(
case KillDriverResponse(master, driverId, success, message) =>
- println(message)
+ logInfo(message)
if (success) {
activeMasterEndpoint = master
pollAndReportStatus(driverId)
@@ -169,13 +169,13 @@ private class ClientEndpoint(
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
if (!lostMasters.contains(remoteAddress)) {
- println(s"Error connecting to master $remoteAddress.")
+ logError(s"Error connecting to master $remoteAddress.")
lostMasters += remoteAddress
// Note that this heuristic does not account for the fact that a Master can recover within
// the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This
// is not currently a concern, however, because this client does not retry submissions.
if (lostMasters.size >= masterEndpoints.size) {
- println("No master is available, exiting.")
+ logError("No master is available, exiting.")
System.exit(-1)
}
}
@@ -183,18 +183,18 @@ private class ClientEndpoint(
override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
if (!lostMasters.contains(remoteAddress)) {
- println(s"Error connecting to master ($remoteAddress).")
- println(s"Cause was: $cause")
+ logError(s"Error connecting to master ($remoteAddress).")
+ logError(s"Cause was: $cause")
lostMasters += remoteAddress
if (lostMasters.size >= masterEndpoints.size) {
- println("No master is available, exiting.")
+ logError("No master is available, exiting.")
System.exit(-1)
}
}
}
override def onError(cause: Throwable): Unit = {
- println(s"Error processing messages, exiting.")
+ logError(s"Error processing messages, exiting.")
cause.printStackTrace()
System.exit(-1)
}
@@ -209,10 +209,12 @@ private class ClientEndpoint(
*/
object Client {
def main(args: Array[String]) {
+ // scalastyle:off println
if (!sys.props.contains("SPARK_SUBMIT")) {
println("WARNING: This client is deprecated and will be removed in a future version of Spark")
println("Use ./bin/spark-submit with \"--master spark://host:port\"")
}
+ // scalastyle:on println
val conf = new SparkConf()
val driverArgs = new ClientArguments(args)
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 42d3296062e6d..72cc330a398da 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -72,9 +72,11 @@ private[deploy] class ClientArguments(args: Array[String]) {
cmd = "launch"
if (!ClientArguments.isValidJarUrl(_jarUrl)) {
+ // scalastyle:off println
println(s"Jar url '${_jarUrl}' is not in valid format.")
println(s"Must be a jar file path in URL format " +
"(e.g. hdfs://host:port/XX.jar, file:///XX.jar)")
+ // scalastyle:on println
printUsageAndExit(-1)
}
@@ -110,7 +112,9 @@ private[deploy] class ClientArguments(args: Array[String]) {
| (default: $DEFAULT_SUPERVISE)
| -v, --verbose Print more debugging output
""".stripMargin
+ // scalastyle:off println
System.err.println(usage)
+ // scalastyle:on println
System.exit(exitCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
index 2954f932b4f41..ccffb36652988 100644
--- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
@@ -76,12 +76,13 @@ private[deploy] object JsonProtocol {
}
def writeMasterState(obj: MasterStateResponse): JObject = {
+ val aliveWorkers = obj.workers.filter(_.isAlive())
("url" -> obj.uri) ~
("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~
- ("cores" -> obj.workers.map(_.cores).sum) ~
- ("coresused" -> obj.workers.map(_.coresUsed).sum) ~
- ("memory" -> obj.workers.map(_.memory).sum) ~
- ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~
+ ("cores" -> aliveWorkers.map(_.cores).sum) ~
+ ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~
+ ("memory" -> aliveWorkers.map(_.memory).sum) ~
+ ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~
("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~
("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~
("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index e99779f299785..c0cab22fa8252 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.fs.Path
-import org.apache.spark.api.r.RBackend
+import org.apache.spark.api.r.{RBackend, RUtils}
import org.apache.spark.util.RedirectThread
/**
@@ -71,9 +71,10 @@ object RRunner {
val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs)
val env = builder.environment()
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
- val sparkHome = System.getenv("SPARK_HOME")
+ val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
+ env.put("SPARKR_PACKAGE_DIR", rPackageDir)
env.put("R_PROFILE_USER",
- Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator))
+ Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator))
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()
@@ -85,7 +86,9 @@ object RRunner {
}
System.exit(returnCode)
} else {
+ // scalastyle:off println
System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds")
+ // scalastyle:on println
System.exit(-1)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 6d14590a1d192..e06b06e06fb4a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -25,6 +25,7 @@ import java.util.{Arrays, Comparator}
import scala.collection.JavaConversions._
import scala.concurrent.duration._
import scala.language.postfixOps
+import scala.util.control.NonFatal
import com.google.common.primitives.Longs
import org.apache.hadoop.conf.Configuration
@@ -178,7 +179,7 @@ class SparkHadoopUtil extends Logging {
private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
val statisticsDataClass =
- Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
+ Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
statisticsDataClass.getDeclaredMethod(methodName)
}
@@ -238,6 +239,14 @@ class SparkHadoopUtil extends Logging {
}.getOrElse(Seq.empty[Path])
}
+ def globPathIfNecessary(pattern: Path): Seq[Path] = {
+ if (pattern.toString.exists("{}[]*?\\".toSet.contains)) {
+ globPath(pattern)
+ } else {
+ Seq(pattern)
+ }
+ }
+
/**
* Lists all the files in a directory with the specified prefix, and does not end with the
* given suffix. The returned {{FileStatus}} instances are sorted by the modification times of
@@ -248,19 +257,25 @@ class SparkHadoopUtil extends Logging {
dir: Path,
prefix: String,
exclusionSuffix: String): Array[FileStatus] = {
- val fileStatuses = remoteFs.listStatus(dir,
- new PathFilter {
- override def accept(path: Path): Boolean = {
- val name = path.getName
- name.startsWith(prefix) && !name.endsWith(exclusionSuffix)
+ try {
+ val fileStatuses = remoteFs.listStatus(dir,
+ new PathFilter {
+ override def accept(path: Path): Boolean = {
+ val name = path.getName
+ name.startsWith(prefix) && !name.endsWith(exclusionSuffix)
+ }
+ })
+ Arrays.sort(fileStatuses, new Comparator[FileStatus] {
+ override def compare(o1: FileStatus, o2: FileStatus): Int = {
+ Longs.compare(o1.getModificationTime, o2.getModificationTime)
}
})
- Arrays.sort(fileStatuses, new Comparator[FileStatus] {
- override def compare(o1: FileStatus, o2: FileStatus): Int = {
- Longs.compare(o1.getModificationTime, o2.getModificationTime)
- }
- })
- fileStatuses
+ fileStatuses
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Error while attempting to list files from application staging dir", e)
+ Array.empty
+ }
}
/**
@@ -356,7 +371,7 @@ object SparkHadoopUtil {
System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if (yarnMode) {
try {
- Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
+ Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
.newInstance()
.asInstanceOf[SparkHadoopUtil]
} catch {
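The new `globPathIfNecessary` helper above only pays the cost of a FileSystem glob when the path actually contains glob metacharacters; otherwise it returns the path unchanged. The check itself is a plain character-set test, shown here as a standalone sketch with no Hadoop dependency:

```scala
object GlobCheckSketch {
  // Same metacharacter set as the patch: {}[]*?\
  private val globChars = "{}[]*?\\".toSet

  def needsGlobbing(path: String): Boolean = path.exists(globChars.contains)

  def main(args: Array[String]): Unit = {
    println(needsGlobbing("/data/2015/07/*.parquet")) // true  -> would call globPath
    println(needsGlobbing("/data/2015/07/part-0001")) // false -> returned as-is
  }
}
```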
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index b1d6ec209d62b..0b39ee8fe3ba0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -37,6 +37,7 @@ import org.apache.ivy.core.settings.IvySettings
import org.apache.ivy.plugins.matcher.GlobPatternMatcher
import org.apache.ivy.plugins.repository.file.FileRepository
import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver}
+import org.apache.spark.api.r.RUtils
import org.apache.spark.SPARK_VERSION
import org.apache.spark.deploy.rest._
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
@@ -79,9 +80,11 @@ object SparkSubmit {
private val SPARK_SHELL = "spark-shell"
private val PYSPARK_SHELL = "pyspark-shell"
private val SPARKR_SHELL = "sparkr-shell"
+ private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip"
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
+ // scalastyle:off println
// Exposed for testing
private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode)
private[spark] var printStream: PrintStream = System.err
@@ -102,11 +105,14 @@ object SparkSubmit {
printStream.println("Type --help for more information.")
exitFn(0)
}
+ // scalastyle:on println
def main(args: Array[String]): Unit = {
val appArgs = new SparkSubmitArguments(args)
if (appArgs.verbose) {
+ // scalastyle:off println
printStream.println(appArgs)
+ // scalastyle:on println
}
appArgs.action match {
case SparkSubmitAction.SUBMIT => submit(appArgs)
@@ -160,7 +166,9 @@ object SparkSubmit {
// makes the message printed to the output by the JVM not very helpful. Instead,
// detect exceptions with empty stack traces here, and treat them differently.
if (e.getStackTrace().length == 0) {
+ // scalastyle:off println
printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
+ // scalastyle:on println
exitFn(1)
} else {
throw e
@@ -178,7 +186,9 @@ object SparkSubmit {
// to use the legacy gateway if the master endpoint turns out to be not a REST server.
if (args.isStandaloneCluster && args.useRest) {
try {
+ // scalastyle:off println
printStream.println("Running Spark using the REST application submission protocol.")
+ // scalastyle:on println
doRunMain()
} catch {
// Fail over to use the legacy submission gateway
@@ -254,6 +264,12 @@ object SparkSubmit {
}
}
+ // Update args.deployMode if it is null. It will be passed down as a Spark property later.
+ (args.deployMode, deployMode) match {
+ case (null, CLIENT) => args.deployMode = "client"
+ case (null, CLUSTER) => args.deployMode = "cluster"
+ case _ =>
+ }
val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER
@@ -339,6 +355,23 @@ object SparkSubmit {
}
}
+ // In YARN mode for an R app, add the SparkR package archive to the list of archives
+ // to be distributed with the job
+ if (args.isR && clusterManager == YARN) {
+ val rPackagePath = RUtils.localSparkRPackagePath
+ if (rPackagePath.isEmpty) {
+ printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.")
+ }
+ val rPackageFile = new File(rPackagePath.get, SPARKR_PACKAGE_ARCHIVE)
+ if (!rPackageFile.exists()) {
+ printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.")
+ }
+ val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath)
+
+ // Assign a symbolic link name "sparkr" to the shipped package.
+ args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr")
+ }
+
// If we're running a R app, set the main class to our specific R runner
if (args.isR && deployMode == CLIENT) {
if (args.primaryResource == SPARKR_SHELL) {
@@ -367,6 +400,8 @@ object SparkSubmit {
// All cluster managers
OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"),
+ OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
+ sysProp = "spark.submit.deployMode"),
OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"),
OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"),
OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"),
@@ -473,8 +508,14 @@ object SparkSubmit {
}
// Let YARN know it's a pyspark app, so it distributes needed libraries.
- if (clusterManager == YARN && args.isPython) {
- sysProps.put("spark.yarn.isPython", "true")
+ if (clusterManager == YARN) {
+ if (args.isPython) {
+ sysProps.put("spark.yarn.isPython", "true")
+ }
+ if (args.principal != null) {
+ require(args.keytab != null, "Keytab must be specified when principal is specified")
+ UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab)
+ }
}
// In yarn-cluster mode, use yarn.Client as a wrapper around the user class
@@ -558,6 +599,7 @@ object SparkSubmit {
sysProps: Map[String, String],
childMainClass: String,
verbose: Boolean): Unit = {
+ // scalastyle:off println
if (verbose) {
printStream.println(s"Main class:\n$childMainClass")
printStream.println(s"Arguments:\n${childArgs.mkString("\n")}")
@@ -565,6 +607,7 @@ object SparkSubmit {
printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}")
printStream.println("\n")
}
+ // scalastyle:on println
val loader =
if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) {
@@ -587,13 +630,15 @@ object SparkSubmit {
var mainClass: Class[_] = null
try {
- mainClass = Class.forName(childMainClass, true, loader)
+ mainClass = Utils.classForName(childMainClass)
} catch {
case e: ClassNotFoundException =>
e.printStackTrace(printStream)
if (childMainClass.contains("thriftserver")) {
+ // scalastyle:off println
printStream.println(s"Failed to load main class $childMainClass.")
printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.")
+ // scalastyle:on println
}
System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
}
@@ -766,7 +811,9 @@ private[spark] object SparkSubmitUtils {
brr.setRoot(repo)
brr.setName(s"repo-${i + 1}")
cr.add(brr)
+ // scalastyle:off println
printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}")
+ // scalastyle:on println
}
}
@@ -829,7 +876,9 @@ private[spark] object SparkSubmitUtils {
val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version)
val dd = new DefaultDependencyDescriptor(ri, false, false)
dd.addDependencyConfiguration(ivyConfName, ivyConfName)
+ // scalastyle:off println
printStream.println(s"${dd.getDependencyId} added as a dependency")
+ // scalastyle:on println
md.addDependency(dd)
}
}
@@ -896,9 +945,11 @@ private[spark] object SparkSubmitUtils {
ivySettings.setDefaultCache(new File(alternateIvyCache, "cache"))
new File(alternateIvyCache, "jars")
}
+ // scalastyle:off println
printStream.println(
s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}")
printStream.println(s"The jars for the packages stored in: $packagesDirectory")
+ // scalastyle:on println
// create a pattern matcher
ivySettings.addMatcher(new GlobPatternMatcher)
// create the dependency resolvers
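Earlier in this file, the R-on-YARN branch appends `sparkr.zip` to the archives list with a `#sparkr` fragment, which is the `<archive-uri>#<link-name>` convention YARN uses when localizing archives. The sketch below illustrates that merge; `mergeArchives` is a hypothetical stand-in for SparkSubmit's private `mergeFileLists` helper (a comma-joined list), and the paths are illustrative.

```scala
object SparkRArchiveSketch {
  // Hypothetical comma-join, standing in for SparkSubmit.mergeFileLists.
  def mergeArchives(existing: String, extra: String): String =
    Seq(Option(existing), Option(extra)).flatten.filter(_.nonEmpty).mkString(",")

  def main(args: Array[String]): Unit = {
    val sparkrArchive = "file:/opt/spark/R/lib/sparkr.zip#sparkr" // path is illustrative
    println(mergeArchives(null, sparkrArchive))
    // => file:/opt/spark/R/lib/sparkr.zip#sparkr
    // On YARN, the archive is extracted in each container and exposed under the
    // link name "sparkr", which is what RUtils.sparkRPackagePath expects on executors.
    println(mergeArchives("hdfs:/deps/extra.zip", sparkrArchive))
    // => hdfs:/deps/extra.zip,file:/opt/spark/R/lib/sparkr.zip#sparkr
  }
}
```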
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 73ab18332feb4..b3710073e330c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -79,6 +79,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
/** Default properties present in the currently defined defaults file. */
lazy val defaultSparkProperties: HashMap[String, String] = {
val defaultProperties = new HashMap[String, String]()
+ // scalastyle:off println
if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile")
Option(propertiesFile).foreach { filename =>
Utils.getPropertiesFromFile(filename).foreach { case (k, v) =>
@@ -86,6 +87,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v")
}
}
+ // scalastyle:on println
defaultProperties
}
@@ -162,6 +164,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
.orNull
executorCores = Option(executorCores)
.orElse(sparkProperties.get("spark.executor.cores"))
+ .orElse(env.get("SPARK_EXECUTOR_CORES"))
.orNull
totalExecutorCores = Option(totalExecutorCores)
.orElse(sparkProperties.get("spark.cores.max"))
@@ -451,6 +454,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
}
private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = {
+ // scalastyle:off println
val outStream = SparkSubmit.printStream
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
@@ -540,6 +544,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
outStream.println("CLI options:")
outStream.println(getSqlShellOptions())
}
+ // scalastyle:on println
SparkSubmit.exitFn(exitCode)
}
@@ -571,7 +576,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
System.setSecurityManager(sm)
try {
- Class.forName(mainClass).getMethod("main", classOf[Array[String]])
+ Utils.classForName(mainClass).getMethod("main", classOf[Array[String]])
.invoke(null, Array(HELP))
} catch {
case e: InvocationTargetException =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala
index c5ac45c6730d3..a98b1fa8f83a1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala
@@ -19,7 +19,9 @@ package org.apache.spark.deploy.client
private[spark] object TestExecutor {
def main(args: Array[String]) {
+ // scalastyle:off println
println("Hello world!")
+ // scalastyle:on println
while (true) {
Thread.sleep(1000)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 2cc465e55fceb..e3060ac3fa1a9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -407,8 +407,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
/**
* Comparison function that defines the sort order for application attempts within the same
- * application. Order is: running attempts before complete attempts, running attempts sorted
- * by start time, completed attempts sorted by end time.
+ * application. Attempts are sorted by descending start time; the most recent
+ * attempt's state reflects the current state of the application.
*
* Normally applications should have a single running attempt; but failure to call sc.stop()
* may cause multiple running attempts to show up.
@@ -418,11 +418,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
private def compareAttemptInfo(
a1: FsApplicationAttemptInfo,
a2: FsApplicationAttemptInfo): Boolean = {
- if (a1.completed == a2.completed) {
- if (a1.completed) a1.endTime >= a2.endTime else a1.startTime >= a2.startTime
- } else {
- !a1.completed
- }
+ a1.startTime >= a2.startTime
}
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index 10638afb74900..a076a9c3f984d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -228,7 +228,7 @@ object HistoryServer extends Logging {
val providerName = conf.getOption("spark.history.provider")
.getOrElse(classOf[FsHistoryProvider].getName())
- val provider = Class.forName(providerName)
+ val provider = Utils.classForName(providerName)
.getConstructor(classOf[SparkConf])
.newInstance(conf)
.asInstanceOf[ApplicationHistoryProvider]
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index 4692d22651c93..18265df9faa2c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -56,6 +56,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin
Utils.loadDefaultSparkProperties(conf, propertiesFile)
private def printUsageAndExit(exitCode: Int) {
+ // scalastyle:off println
System.err.println(
"""
|Usage: HistoryServer [options]
@@ -84,6 +85,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin
| spark.history.fs.updateInterval How often to reload log data from storage
| (in seconds, default: 10)
|""".stripMargin)
+ // scalastyle:on println
System.exit(exitCode)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index f459ed5b3a1a1..aa379d4cd61e7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -21,9 +21,8 @@ import java.io._
import scala.reflect.ClassTag
-import akka.serialization.Serialization
-
import org.apache.spark.Logging
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer}
import org.apache.spark.util.Utils
@@ -32,11 +31,11 @@ import org.apache.spark.util.Utils
* Files are deleted when applications and workers are removed.
*
* @param dir Directory to store files. Created if non-existent (but not recursively).
- * @param serialization Used to serialize our objects.
+ * @param serializer Used to serialize our objects.
*/
private[master] class FileSystemPersistenceEngine(
val dir: String,
- val serialization: Serialization)
+ val serializer: Serializer)
extends PersistenceEngine with Logging {
new File(dir).mkdir()
@@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine(
private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
- val out = new FileOutputStream(file)
+ val fileOut = new FileOutputStream(file)
+ var out: SerializationStream = null
Utils.tryWithSafeFinally {
- out.write(serialized)
+ out = serializer.newInstance().serializeStream(fileOut)
+ out.writeObject(value)
} {
- out.close()
+ fileOut.close()
+ if (out != null) {
+ out.close()
+ }
}
}
private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
- val fileData = new Array[Byte](file.length().asInstanceOf[Int])
- val dis = new DataInputStream(new FileInputStream(file))
+ val fileIn = new FileInputStream(file)
+ var in: DeserializationStream = null
try {
- dis.readFully(fileData)
+ in = serializer.newInstance().deserializeStream(fileIn)
+ in.readObject[T]()
} finally {
- dis.close()
+ fileIn.close()
+ if (in != null) {
+ in.close()
+ }
}
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
- serializer.fromBinary(fileData).asInstanceOf[T]
}
}
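The rewritten serializeIntoFile/deserializeFromFile above follow one resource-handling shape: open the raw file stream, lazily wrap it in a (de)serialization stream, and close both even if the wrapping step failed. A generic sketch of that pattern, using plain java.io object streams rather than Spark's Serializer API:

```scala
import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}

// Generic version of the stream-handling pattern above, with Java serialization
// standing in for org.apache.spark.serializer.Serializer.
object SafeStreamSketch {
  def writeObjectToFile(file: File, value: AnyRef): Unit = {
    val fileOut = new FileOutputStream(file)
    var out: ObjectOutputStream = null
    try {
      out = new ObjectOutputStream(fileOut) // may fail; fileOut is still closed below
      out.writeObject(value)
    } finally {
      if (out != null) out.close() // closing the wrapper also flushes it
      fileOut.close()
    }
  }

  def readObjectFromFile[T](file: File): T = {
    val fileIn = new FileInputStream(file)
    var in: ObjectInputStream = null
    try {
      in = new ObjectInputStream(fileIn)
      in.readObject().asInstanceOf[T]
    } finally {
      if (in != null) in.close()
      fileIn.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val f = File.createTempFile("persistence-sketch", ".bin")
    writeObjectToFile(f, "driver_20150701")
    println(readObjectFromFile[String](f)) // driver_20150701
    f.delete()
  }
}
```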
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 48070768f6edb..51b3f0dead73e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.language.postfixOps
import scala.util.Random
-import akka.serialization.Serialization
-import akka.serialization.SerializationExtension
import org.apache.hadoop.fs.Path
-import org.apache.spark.rpc.akka.AkkaRpcEnv
import org.apache.spark.rpc._
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
@@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.deploy.rest.StandaloneRestServer
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
+import org.apache.spark.serializer.{JavaSerializer, Serializer}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
@@ -58,9 +56,6 @@ private[master] class Master(
private val forwardMessageThread =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")
- // TODO Remove it once we don't use akka.serialization.Serialization
- private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
-
private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
@@ -161,20 +156,21 @@ private[master] class Master(
masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
+ val serializer = new JavaSerializer(conf)
val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
val zkFactory =
- new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem))
+ new ZooKeeperRecoveryModeFactory(conf, serializer)
(zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
val fsFactory =
- new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem))
+ new FileSystemRecoveryModeFactory(conf, serializer)
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
case "CUSTOM" =>
- val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
- val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
- .newInstance(conf, SerializationExtension(actorSystem))
+ val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory"))
+ val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer])
+ .newInstance(conf, serializer)
.asInstanceOf[StandaloneRecoveryModeFactory]
(factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
@@ -213,7 +209,7 @@ private[master] class Master(
override def receive: PartialFunction[Any, Unit] = {
case ElectedLeader => {
- val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData()
+ val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv)
state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
RecoveryState.ALIVE
} else {
@@ -545,6 +541,7 @@ private[master] class Master(
/**
* Schedule executors to be launched on the workers.
+ * Returns an array containing the number of cores assigned to each worker.
*
* There are two modes of launching executors. The first attempts to spread out an application's
* executors on as many workers as possible, while the second does the opposite (i.e. launch them
@@ -555,39 +552,77 @@ private[master] class Master(
* multiple executors from the same application may be launched on the same worker if the worker
* has enough cores and memory. Otherwise, each executor grabs all the cores available on the
* worker by default, in which case only one executor may be launched on each worker.
+ *
+ * It is important to allocate coresPerExecutor cores on each worker at a time (instead of
+ * 1 core at a time). Consider the following example: a cluster has 4 workers with 16 cores
+ * each, and the user requests 3 executors (spark.cores.max = 48, spark.executor.cores = 16).
+ * If 1 core were allocated at a time, 12 cores from each worker would end up assigned to
+ * each executor. Since 12 < 16, no executor would launch [SPARK-8881].
*/
- private def startExecutorsOnWorkers(): Unit = {
- // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
- // in the queue, then the second app, etc.
- if (spreadOutApps) {
- // Try to spread out each app among all the workers, until it has all its cores
- for (app <- waitingApps if app.coresLeft > 0) {
- val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE)
- .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB &&
- worker.coresFree >= app.desc.coresPerExecutor.getOrElse(1))
- .sortBy(_.coresFree).reverse
- val numUsable = usableWorkers.length
- val assigned = new Array[Int](numUsable) // Number of cores to give on each node
- var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum)
- var pos = 0
- while (toAssign > 0) {
- if (usableWorkers(pos).coresFree - assigned(pos) > 0) {
- toAssign -= 1
- assigned(pos) += 1
+ private def scheduleExecutorsOnWorkers(
+ app: ApplicationInfo,
+ usableWorkers: Array[WorkerInfo],
+ spreadOutApps: Boolean): Array[Int] = {
+ // If the number of cores per executor is not specified, then we can just schedule
+ // 1 core at a time since we expect a single executor to be launched on each worker
+ val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(1)
+ val memoryPerExecutor = app.desc.memoryPerExecutorMB
+ val numUsable = usableWorkers.length
+ val assignedCores = new Array[Int](numUsable) // Number of cores to give to each worker
+ val assignedMemory = new Array[Int](numUsable) // Amount of memory to give to each worker
+ var coresToAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum)
+ var freeWorkers = (0 until numUsable).toIndexedSeq
+
+ def canLaunchExecutor(pos: Int): Boolean = {
+ usableWorkers(pos).coresFree - assignedCores(pos) >= coresPerExecutor &&
+ usableWorkers(pos).memoryFree - assignedMemory(pos) >= memoryPerExecutor
+ }
+
+ while (coresToAssign >= coresPerExecutor && freeWorkers.nonEmpty) {
+ freeWorkers = freeWorkers.filter(canLaunchExecutor)
+ freeWorkers.foreach { pos =>
+ var keepScheduling = true
+ while (keepScheduling && canLaunchExecutor(pos) && coresToAssign >= coresPerExecutor) {
+ coresToAssign -= coresPerExecutor
+ assignedCores(pos) += coresPerExecutor
+ // If cores per executor is not set, we assign 1 core at a time; this does not
+ // mean we intend to launch 1 executor for each core assigned
+ if (app.desc.coresPerExecutor.isDefined) {
+ assignedMemory(pos) += memoryPerExecutor
+ }
+
+ // Spreading out an application means spreading out its executors across as
+ // many workers as possible. If we are not spreading out, then we should keep
+ // scheduling executors on this worker until we use all of its resources.
+ // Otherwise, just move on to the next worker.
+ if (spreadOutApps) {
+ keepScheduling = false
}
- pos = (pos + 1) % numUsable
- }
- // Now that we've decided how many cores to give on each node, let's actually give them
- for (pos <- 0 until numUsable if assigned(pos) > 0) {
- allocateWorkerResourceToExecutors(app, assigned(pos), usableWorkers(pos))
}
}
- } else {
- // Pack each app into as few workers as possible until we've assigned all its cores
- for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) {
- for (app <- waitingApps if app.coresLeft > 0) {
- allocateWorkerResourceToExecutors(app, app.coresLeft, worker)
- }
+ }
+ assignedCores
+ }
+
+ /**
+ * Schedule and launch executors on workers
+ */
+ private def startExecutorsOnWorkers(): Unit = {
+ // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
+ // in the queue, then the second app, etc.
+ for (app <- waitingApps if app.coresLeft > 0) {
+ val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor
+ // Filter out workers that don't have enough resources to launch an executor
+ val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE)
+ .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB &&
+ worker.coresFree >= coresPerExecutor.getOrElse(1))
+ .sortBy(_.coresFree).reverse
+ val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps)
+
+ // Now that we've decided how many cores to allocate on each worker, let's allocate them
+ for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) {
+ allocateWorkerResourceToExecutors(
+ app, assignedCores(pos), coresPerExecutor, usableWorkers(pos))
}
}
}
@@ -595,19 +630,22 @@ private[master] class Master(
/**
* Allocate a worker's resources to one or more executors.
* @param app the info of the application which the executors belong to
- * @param coresToAllocate cores on this worker to be allocated to this application
+ * @param assignedCores number of cores on this worker for this application
+ * @param coresPerExecutor number of cores per executor
* @param worker the worker info
*/
private def allocateWorkerResourceToExecutors(
app: ApplicationInfo,
- coresToAllocate: Int,
+ assignedCores: Int,
+ coresPerExecutor: Option[Int],
worker: WorkerInfo): Unit = {
- val memoryPerExecutor = app.desc.memoryPerExecutorMB
- val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(coresToAllocate)
- var coresLeft = coresToAllocate
- while (coresLeft >= coresPerExecutor && worker.memoryFree >= memoryPerExecutor) {
- val exec = app.addExecutor(worker, coresPerExecutor)
- coresLeft -= coresPerExecutor
+ // If the number of cores per executor is specified, we divide the cores assigned
+ // to this worker evenly among the executors with no remainder.
+ // Otherwise, we launch a single executor that grabs all the assignedCores on this worker.
+ val numExecutors = coresPerExecutor.map { assignedCores / _ }.getOrElse(1)
+ val coresToAssign = coresPerExecutor.getOrElse(assignedCores)
+ for (i <- 1 to numExecutors) {
+ val exec = app.addExecutor(worker, coresToAssign)
launchExecutor(worker, exec)
app.state = ApplicationState.RUNNING
}
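The SPARK-8881 example in the comment above can be reproduced with a few lines: with 4 workers of 16 free cores each and spark.executor.cores = 16, allocating one core at a time strands 12 cores on every worker, whereas allocating coresPerExecutor at a time fills three workers completely. The sketch below is a simplified standalone model of the spread-out scheduling loop (cores only, no memory check), not the actual Master code.

```scala
// Simplified model of the spread-out scheduling loop: it only tracks free cores
// and exists to contrast the two step sizes (1 core vs coresPerExecutor).
object ExecutorSchedulingSketch {
  def assign(freeCores: Array[Int], coresToAssign: Int, step: Int): Array[Int] = {
    val assigned = new Array[Int](freeCores.length)
    var remaining = coresToAssign
    var progressed = true
    while (remaining >= step && progressed) {
      progressed = false
      for (pos <- freeCores.indices
           if freeCores(pos) - assigned(pos) >= step && remaining >= step) {
        assigned(pos) += step
        remaining -= step
        progressed = true
      }
    }
    assigned
  }

  def main(args: Array[String]): Unit = {
    val workers = Array.fill(4)(16) // 4 workers, 16 free cores each
    val requested = 48              // spark.cores.max
    val coresPerExecutor = 16       // spark.executor.cores

    println(assign(workers, requested, step = 1).mkString(","))
    // 12,12,12,12 -> no worker reaches 16, so no executor can launch (SPARK-8881)
    println(assign(workers, requested, step = coresPerExecutor).mkString(","))
    // 16,16,16,0 -> exactly 3 executors of 16 cores
  }
}
```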
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
index 435b9b12f83b8..44cefbc77f08e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -85,6 +85,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) {
* Print usage and exit JVM with the given exit code.
*/
private def printUsageAndExit(exitCode: Int) {
+ // scalastyle:off println
System.err.println(
"Usage: Master [options]\n" +
"\n" +
@@ -95,6 +96,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) {
" --webui-port PORT Port for web UI (default: 8080)\n" +
" --properties-file FILE Path to a custom Spark properties file.\n" +
" Default is conf/spark-defaults.conf.")
+ // scalastyle:on println
System.exit(exitCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index a03d460509e03..58a00bceee6af 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.master
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rpc.RpcEnv
import scala.reflect.ClassTag
@@ -80,8 +81,11 @@ abstract class PersistenceEngine {
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
- final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ final def readPersistedData(
+ rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
+ rpcEnv.deserialize { () =>
+ (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ }
}
def close() {}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
index 351db8fab2041..c4c3283fb73f7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -17,10 +17,9 @@
package org.apache.spark.deploy.master
-import akka.serialization.Serialization
-
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.serializer.Serializer
/**
* ::DeveloperApi::
@@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi
*
*/
@DeveloperApi
-abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) {
+abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) {
/**
* PersistenceEngine defines how the persistent data(Information about worker, driver etc..)
@@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial
* LeaderAgent in this case is a no-op. Since leader is forever leader as the actual
* recovery is made by restoring from filesystem.
*/
-private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
extends StandaloneRecoveryModeFactory(conf, serializer) with Logging {
val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
@@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer:
}
}
-private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
extends StandaloneRecoveryModeFactory(conf, serializer) {
def createPersistenceEngine(): PersistenceEngine = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index 471811037e5e2..f751966605206 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -105,4 +105,6 @@ private[spark] class WorkerInfo(
def setState(state: WorkerState.Value): Unit = {
this.state = state
}
+
+ def isAlive(): Boolean = this.state == WorkerState.ALIVE
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 328d95a7a0c68..563831cc6b8dd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.master
-import akka.serialization.Serialization
+import java.nio.ByteBuffer
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.SparkCuratorUtil
+import org.apache.spark.serializer.Serializer
-private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
+private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer)
extends PersistenceEngine
with Logging {
@@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat
}
private def serializeIntoFile(path: String, value: AnyRef) {
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
- zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
+ val serialized = serializer.newInstance().serialize(value)
+ val bytes = new Array[Byte](serialized.remaining())
+ serialized.get(bytes)
+ zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes)
}
private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = {
val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
try {
- Some(serializer.fromBinary(fileData).asInstanceOf[T])
+ Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData)))
} catch {
case e: Exception => {
logWarning("Exception while reading persisted file, deleting", e)
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
index 894cb78d8591a..5accaf78d0a51 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
@@ -54,7 +54,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
case ("--master" | "-m") :: value :: tail =>
if (!value.startsWith("mesos://")) {
+ // scalastyle:off println
System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)")
+ // scalastyle:on println
System.exit(1)
}
masterUrl = value.stripPrefix("mesos://")
@@ -73,7 +75,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
case Nil => {
if (masterUrl == null) {
+ // scalastyle:off println
System.err.println("--master is required")
+ // scalastyle:on println
printUsageAndExit(1)
}
}
@@ -83,6 +87,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
}
private def printUsageAndExit(exitCode: Int): Unit = {
+ // scalastyle:off println
System.err.println(
"Usage: MesosClusterDispatcher [options]\n" +
"\n" +
@@ -96,6 +101,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
" Zookeeper for persistence\n" +
" --properties-file FILE Path to a custom Spark properties file.\n" +
" Default is conf/spark-defaults.conf.")
+ // scalastyle:on println
System.exit(exitCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
index e6615a3174ce1..ef5a7e35ad562 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
@@ -128,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage {
*/
def fromJson(json: String): SubmitRestProtocolMessage = {
val className = parseAction(json)
- val clazz = Class.forName(packagePrefix + "." + className)
+ val clazz = Utils.classForName(packagePrefix + "." + className)
.asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage])
fromJson(json, clazz)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index d1a12b01e78f7..6799f78ec0c19 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -53,14 +53,16 @@ object DriverWrapper {
Thread.currentThread.setContextClassLoader(loader)
// Delegate to supplied main class
- val clazz = Class.forName(mainClass, true, loader)
+ val clazz = Utils.classForName(mainClass)
val mainMethod = clazz.getMethod("main", classOf[Array[String]])
mainMethod.invoke(null, extraArgs.toArray[String])
rpcEnv.shutdown()
case _ =>
+ // scalastyle:off println
System.err.println("Usage: DriverWrapper [options]")
+ // scalastyle:on println
System.exit(-1)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 1d2ecab517613..5181142c5f80e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -121,6 +121,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
* Print usage and exit JVM with the given exit code.
*/
def printUsageAndExit(exitCode: Int) {
+ // scalastyle:off println
System.err.println(
"Usage: Worker [options] \n" +
"\n" +
@@ -136,6 +137,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
" --webui-port PORT Port for web UI (default: 8081)\n" +
" --properties-file FILE Path to a custom Spark properties file.\n" +
" Default is conf/spark-defaults.conf.")
+ // scalastyle:on println
System.exit(exitCode)
}
@@ -147,6 +149,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
val ibmVendor = System.getProperty("java.vendor").contains("IBM")
var totalMb = 0
try {
+ // scalastyle:off classforname
val bean = ManagementFactory.getOperatingSystemMXBean()
if (ibmVendor) {
val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean")
@@ -157,10 +160,13 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize")
totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
}
+ // scalastyle:on classforname
} catch {
case e: Exception => {
totalMb = 2*1024
+ // scalastyle:off println
System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
+ // scalastyle:on println
}
}
// Leave out 1 GB for the operating system, but don't return a negative memory size
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 34d4cfdca7732..fcd76ec52742a 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -235,7 +235,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
argv = tail
case Nil =>
case tail =>
+ // scalastyle:off println
System.err.println(s"Unrecognized options: ${tail.mkString(" ")}")
+ // scalastyle:on println
printUsageAndExit()
}
}
@@ -249,6 +251,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}
private def printUsageAndExit() = {
+ // scalastyle:off println
System.err.println(
"""
|"Usage: CoarseGrainedExecutorBackend [options]
@@ -262,6 +265,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
| --worker-url
| --user-class-path
|""".stripMargin)
+ // scalastyle:on println
System.exit(1)
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 8f916e0502ecb..7bc7fce7ae8dd 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -209,15 +209,19 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
- val value = try {
- task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+ var threwException = true
+ val (value, accumUpdates) = try {
+ val res = task.run(
+ taskAttemptId = taskId,
+ attemptNumber = attemptNumber,
+ metricsSystem = env.metricsSystem)
+ threwException = false
+ res
} finally {
- // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
- // when changing this, make sure to update both copies.
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
if (freedMemory > 0) {
val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
- if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+ if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
throw new SparkException(errMsg)
} else {
logError(errMsg)
@@ -247,7 +251,6 @@ private[spark] class Executor(
m.setResultSerializationTime(afterSerialization - beforeSerialization)
}
- val accumUpdates = Accumulators.values
val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
@@ -310,12 +313,6 @@ private[spark] class Executor(
}
} finally {
- // Release memory used by this thread for shuffles
- env.shuffleMemoryManager.releaseMemoryForThisThread()
- // Release memory used by this thread for unrolling blocks
- env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
- // Release memory used by this thread for accumulators
- Accumulators.clear()
runningTasks.remove(taskId)
}
}
@@ -356,7 +353,7 @@ private[spark] class Executor(
logInfo("Using REPL class URI: " + classUri)
try {
val _userClassPathFirst: java.lang.Boolean = userClassPathFirst
- val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
+ val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[SparkConf], classOf[String],
classOf[ClassLoader], classOf[Boolean])
@@ -424,6 +421,7 @@ private[spark] class Executor(
metrics.updateShuffleReadMetrics()
metrics.updateInputMetrics()
metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+ metrics.updateAccumulators()
if (isLocal) {
// JobProgressListener will hold a reference to it during
@@ -443,7 +441,7 @@ private[spark] class Executor(
try {
val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message)
if (response.reregisterBlockManager) {
- logWarning("Told to re-register on heartbeat")
+ logInfo("Told to re-register on heartbeat")
env.blockManager.reregister()
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index a3b4561b07e7f..42207a9553592 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,11 +17,15 @@
package org.apache.spark.executor
+import java.io.{IOException, ObjectInputStream}
+import java.util.concurrent.ConcurrentHashMap
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.DataReadMethod.DataReadMethod
import org.apache.spark.storage.{BlockId, BlockStatus}
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -210,10 +214,42 @@ class TaskMetrics extends Serializable {
private[spark] def updateInputMetrics(): Unit = synchronized {
inputMetrics.foreach(_.updateBytesRead())
}
+
+ @throws(classOf[IOException])
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ in.defaultReadObject()
+ // Get the hostname from cached data: the number of distinct hostnames is on the order of the
+ // number of nodes in the cluster, so reusing a cached hostname reduces the number of String
+ // objects and alleviates GC overhead.
+ _hostname = TaskMetrics.getCachedHostName(_hostname)
+ }
+
+ private var _accumulatorUpdates: Map[Long, Any] = Map.empty
+ @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null
+
+ private[spark] def updateAccumulators(): Unit = synchronized {
+ _accumulatorUpdates = _accumulatorsUpdater()
+ }
+
+ /**
+ * Return the latest updates of accumulators in this task.
+ */
+ def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates
+
+ private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = {
+ _accumulatorsUpdater = accumulatorsUpdater
+ }
}
private[spark] object TaskMetrics {
+ private val hostNameCache = new ConcurrentHashMap[String, String]()
+
def empty: TaskMetrics = new TaskMetrics
+
+ def getCachedHostName(host: String): String = {
+ val canonicalHost = hostNameCache.putIfAbsent(host, host)
+ if (canonicalHost != null) canonicalHost else host
+ }
}
/**
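The new TaskMetrics.getCachedHostName above relies on ConcurrentHashMap.putIfAbsent to intern hostnames. A self-contained sketch of that interning idiom (the object name is illustrative):

```scala
import java.util.concurrent.ConcurrentHashMap

object HostNameInterner {
  private val cache = new ConcurrentHashMap[String, String]()

  /** Returns one canonical String instance per distinct hostname. */
  def intern(host: String): String = {
    // putIfAbsent returns the previously mapped value, or null if this call inserted the key.
    val existing = cache.putIfAbsent(host, host)
    if (existing != null) existing else host
  }
}
```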
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
index c219d21fbefa9..532850dd57716 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -21,6 +21,8 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, LongWritable}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+
+import org.apache.spark.Logging
import org.apache.spark.deploy.SparkHadoopUtil
/**
@@ -39,7 +41,8 @@ private[spark] object FixedLengthBinaryInputFormat {
}
private[spark] class FixedLengthBinaryInputFormat
- extends FileInputFormat[LongWritable, BytesWritable] {
+ extends FileInputFormat[LongWritable, BytesWritable]
+ with Logging {
private var recordLength = -1
@@ -51,7 +54,7 @@ private[spark] class FixedLengthBinaryInputFormat
recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
}
if (recordLength <= 0) {
- println("record length is less than 0, file cannot be split")
+ logDebug("record length is less than 0, file cannot be split")
false
} else {
true
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 0d8ac1f80a9f4..607d5a321efca 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -63,8 +63,7 @@ private[spark] object CompressionCodec {
def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
val codec = try {
- val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader)
- .getConstructor(classOf[SparkConf])
+ val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
} catch {
case e: ClassNotFoundException => None
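Several hunks in this patch swap Class.forName for Utils.classForName. A sketch of what such a helper typically does, assuming the usual "context class loader first" policy (the helper name and fallback are assumptions here, not Spark's exact implementation):

```scala
object Reflect {
  /** Loads a class by name, preferring the thread's context class loader. */
  def classForName(name: String): Class[_] = {
    val loader = Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getClass.getClassLoader)
    // `initialize = true` runs static initializers, matching Class.forName(name, true, loader).
    Class.forName(name, true, loader)
  }

  def main(args: Array[String]): Unit = {
    // Look up a constructor on the loaded class and instantiate it reflectively,
    // analogous to createCodec calling the SparkConf constructor above.
    val listClass = classForName("java.util.ArrayList")
    val list = listClass.getConstructor(classOf[Int]).newInstance(Int.box(16))
    assert(list.isInstanceOf[java.util.ArrayList[_]])
  }
}
```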
diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index 818f7a4c8d422..87df42748be44 100644
--- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.util.{Utils => SparkUtils}
private[spark]
trait SparkHadoopMapRedUtil {
@@ -64,10 +65,10 @@ trait SparkHadoopMapRedUtil {
private def firstAvailableClass(first: String, second: String): Class[_] = {
try {
- Class.forName(first)
+ SparkUtils.classForName(first)
} catch {
case e: ClassNotFoundException =>
- Class.forName(second)
+ SparkUtils.classForName(second)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
index 390d148bc97f9..943ebcb7bd0a1 100644
--- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
+import org.apache.spark.util.Utils
private[spark]
trait SparkHadoopMapReduceUtil {
@@ -46,7 +47,7 @@ trait SparkHadoopMapReduceUtil {
isMap: Boolean,
taskId: Int,
attemptId: Int): TaskAttemptID = {
- val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
+ val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
// First, attempt to use the old-style constructor that takes a boolean isMap
// (not available in YARN)
@@ -57,7 +58,7 @@ trait SparkHadoopMapReduceUtil {
} catch {
case exc: NoSuchMethodException => {
// If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
- val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
+ val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType")
.asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
taskTypeClass, if (isMap) "MAP" else "REDUCE")
@@ -71,10 +72,10 @@ trait SparkHadoopMapReduceUtil {
private def firstAvailableClass(first: String, second: String): Class[_] = {
try {
- Class.forName(first)
+ Utils.classForName(first)
} catch {
case e: ClassNotFoundException =>
- Class.forName(second)
+ Utils.classForName(second)
}
}
}
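Both SparkHadoopMapRedUtil and SparkHadoopMapReduceUtil use the same try-one-class-then-the-other idiom to bridge Hadoop 1.x and 2.x APIs. A standalone sketch (the class names in the example comment are placeholders, not the Hadoop ones):

```scala
object ClassProbe {
  /** Loads `first` if it is on the classpath, otherwise falls back to `second`. */
  def firstAvailableClass(first: String, second: String): Class[_] = {
    try {
      Class.forName(first)
    } catch {
      case _: ClassNotFoundException => Class.forName(second)
    }
  }
}

// e.g. ClassProbe.firstAvailableClass("java.util.concurrent.ConcurrentHashMap", "java.util.Hashtable")
```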
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index ed5131c79fdc5..4517f465ebd3b 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -20,6 +20,8 @@ package org.apache.spark.metrics
import java.util.Properties
import java.util.concurrent.TimeUnit
+import org.apache.spark.util.Utils
+
import scala.collection.mutable
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
@@ -140,6 +142,9 @@ private[spark] class MetricsSystem private (
} else { defaultName }
}
+ def getSourcesByName(sourceName: String): Seq[Source] =
+ sources.filter(_.sourceName == sourceName)
+
def registerSource(source: Source) {
sources += source
try {
@@ -166,7 +171,7 @@ private[spark] class MetricsSystem private (
sourceConfigs.foreach { kv =>
val classPath = kv._2.getProperty("class")
try {
- val source = Class.forName(classPath).newInstance()
+ val source = Utils.classForName(classPath).newInstance()
registerSource(source.asInstanceOf[Source])
} catch {
case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e)
@@ -182,7 +187,7 @@ private[spark] class MetricsSystem private (
val classPath = kv._2.getProperty("class")
if (null != classPath) {
try {
- val sink = Class.forName(classPath)
+ val sink = Utils.classForName(classPath)
.getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
.newInstance(kv._2, registry, securityMgr)
if (kv._1 == "servlet") {
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
index 67a376102994c..79cb0640c8672 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
@@ -57,16 +57,6 @@ private[nio] class BlockMessage() {
}
def set(buffer: ByteBuffer) {
- /*
- println()
- println("BlockMessage: ")
- while(buffer.remaining > 0) {
- print(buffer.get())
- }
- buffer.rewind()
- println()
- println()
- */
typ = buffer.getInt()
val idLength = buffer.getInt()
val idBuilder = new StringBuilder(idLength)
@@ -138,18 +128,6 @@ private[nio] class BlockMessage() {
buffers += data
}
- /*
- println()
- println("BlockMessage: ")
- buffers.foreach(b => {
- while(b.remaining > 0) {
- print(b.get())
- }
- b.rewind()
- })
- println()
- println()
- */
Message.createBufferMessage(buffers)
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
index 7d0806f0c2580..f1c9ea8b64ca3 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
@@ -43,16 +43,6 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage])
val newBlockMessages = new ArrayBuffer[BlockMessage]()
val buffer = bufferMessage.buffers(0)
buffer.clear()
- /*
- println()
- println("BlockMessageArray: ")
- while(buffer.remaining > 0) {
- print(buffer.get())
- }
- buffer.rewind()
- println()
- println()
- */
while (buffer.remaining() > 0) {
val size = buffer.getInt()
logDebug("Creating block message of size " + size + " bytes")
@@ -86,23 +76,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage])
logDebug("Buffer list:")
buffers.foreach((x: ByteBuffer) => logDebug("" + x))
- /*
- println()
- println("BlockMessageArray: ")
- buffers.foreach(b => {
- while(b.remaining > 0) {
- print(b.get())
- }
- b.rewind()
- })
- println()
- println()
- */
Message.createBufferMessage(buffers)
}
}
-private[nio] object BlockMessageArray {
+private[nio] object BlockMessageArray extends Logging {
def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
val newBlockMessageArray = new BlockMessageArray()
@@ -123,10 +101,10 @@ private[nio] object BlockMessageArray {
}
}
val blockMessageArray = new BlockMessageArray(blockMessages)
- println("Block message array created")
+ logDebug("Block message array created")
val bufferMessage = blockMessageArray.toBufferMessage
- println("Converted to buffer message")
+ logDebug("Converted to buffer message")
val totalSize = bufferMessage.size
val newBuffer = ByteBuffer.allocate(totalSize)
@@ -138,10 +116,11 @@ private[nio] object BlockMessageArray {
})
newBuffer.flip
val newBufferMessage = Message.createBufferMessage(newBuffer)
- println("Copied to new buffer message, size = " + newBufferMessage.size)
+ logDebug("Copied to new buffer message, size = " + newBufferMessage.size)
val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
- println("Converted back to block message array")
+ logDebug("Converted back to block message array")
+ // scalastyle:off println
newBlockMessageArray.foreach(blockMessage => {
blockMessage.getType match {
case BlockMessage.TYPE_PUT_BLOCK => {
@@ -154,6 +133,7 @@ private[nio] object BlockMessageArray {
}
}
})
+ // scalastyle:on println
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index c0bca2c4bc994..9143918790381 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -1016,7 +1016,9 @@ private[spark] object ConnectionManager {
val conf = new SparkConf
val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ // scalastyle:off println
println("Received [" + msg + "] from [" + id + "]")
+ // scalastyle:on println
None
})
@@ -1033,6 +1035,7 @@ private[spark] object ConnectionManager {
System.gc()
}
+ // scalastyle:off println
def testSequentialSending(manager: ConnectionManager) {
println("--------------------------")
println("Sequential Sending")
@@ -1150,4 +1153,5 @@ private[spark] object ConnectionManager {
println()
}
}
+ // scalastyle:on println
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 658e8c8b89318..130b58882d8ee 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -94,13 +94,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
}
override def getDependencies: Seq[Dependency[_]] = {
- rdds.map { rdd: RDD[_ <: Product2[K, _]] =>
+ rdds.map { rdd: RDD[_] =>
if (rdd.partitioner == Some(part)) {
logDebug("Adding one-to-one dependency with " + rdd)
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer)
+ new ShuffleDependency[K, Any, CoGroupCombiner](
+ rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer)
}
}
}
@@ -133,7 +134,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
// A list of (rdd iterator, dependency number) pairs
val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
- case oneToOneDependency: OneToOneDependency[Product2[K, Any]] =>
+ case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked =>
val dependencyPartition = split.narrowDeps(depNum).get.split
// Read them from the parent
val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
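The `@unchecked` added to the OneToOneDependency pattern above acknowledges that the type argument cannot be verified at runtime because of erasure. A minimal illustration of that annotation, using a plain List as a stand-in:

```scala
object UncheckedMatch {
  def describe(value: Any): String = value match {
    // Only `List` itself can be checked at runtime; the element type is erased, and
    // `@unchecked` suppresses the compiler warning about the unverifiable `Int` argument.
    case xs: List[Int] @unchecked => s"a list, assumed to contain Ints: $xs"
    case other => s"something else: $other"
  }
}
```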
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 663eebb8e4191..90d9735cb3f69 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -69,7 +69,7 @@ private[spark] case class CoalescedRDDPartition(
* the preferred location of each new partition overlaps with as many preferred locations of its
* parent partitions
* @param prev RDD to be coalesced
- * @param maxPartitions number of desired partitions in the coalesced RDD
+ * @param maxPartitions number of desired partitions in the coalesced RDD (must be positive)
* @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
*/
private[spark] class CoalescedRDD[T: ClassTag](
@@ -78,6 +78,9 @@ private[spark] class CoalescedRDD[T: ClassTag](
balanceSlack: Double = 0.10)
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
+ require(maxPartitions > 0 || maxPartitions == prev.partitions.length,
+ s"Number of partitions ($maxPartitions) must be positive.")
+
override def getPartitions: Array[Partition] = {
val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack)
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index bee59a437f120..f1c17369cb48c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -383,11 +383,11 @@ private[spark] object HadoopRDD extends Logging {
private[spark] class SplitInfoReflections {
val inputSplitWithLocationInfo =
- Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
+ Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
- val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
+ val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit")
val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
- val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
+ val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo")
val isInMemory = splitLocationInfo.getMethod("isInMemory")
val getLocation = splitLocationInfo.getMethod("getLocation")
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index f827270ee6a44..f83a051f5da11 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -128,7 +128,7 @@ class NewHadoopRDD[K, V](
configurable.setConf(conf)
case _ =>
}
- val reader = format.createRecordReader(
+ private var reader = format.createRecordReader(
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
@@ -141,6 +141,12 @@ class NewHadoopRDD[K, V](
override def hasNext: Boolean = {
if (!finished && !havePair) {
finished = !reader.nextKeyValue
+ if (finished) {
+ // Close and release the reader here; close() will also be called when the task
+ // completes, but for tasks that read from many files, it helps to release the
+ // resources early.
+ close()
+ }
havePair = !finished
}
!finished
@@ -159,18 +165,23 @@ class NewHadoopRDD[K, V](
private def close() {
try {
- reader.close()
- if (bytesReadCallback.isDefined) {
- inputMetrics.updateBytesRead()
- } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
- split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
- // If we can't get the bytes read from the FS stats, fall back to the split size,
- // which may be inaccurate.
- try {
- inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength)
- } catch {
- case e: java.io.IOException =>
- logWarning("Unable to get input size to set InputMetrics for task", e)
+ if (reader != null) {
+ // Close reader and release it
+ reader.close()
+ reader = null
+
+ if (bytesReadCallback.isDefined) {
+ inputMetrics.updateBytesRead()
+ } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
+ split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
+ // If we can't get the bytes read from the FS stats, fall back to the split size,
+ // which may be inaccurate.
+ try {
+ inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength)
+ } catch {
+ case e: java.io.IOException =>
+ logWarning("Unable to get input size to set InputMetrics for task", e)
+ }
}
}
} catch {
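The change to close() above makes it safe to call more than once by nulling out the reader after the first call, which is what lets the iterator release the reader as soon as the input is exhausted. A generic sketch of that pattern, using java.io.Closeable so it stands alone:

```scala
import java.io.Closeable

class EagerlyClosingIterator[A](underlying: Iterator[A], resource: Closeable) extends Iterator[A] {
  private var res: Closeable = resource

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    // Release the resource as soon as the input is exhausted; a later close() is then a no-op.
    if (!more) close()
    more
  }

  override def next(): A = underlying.next()

  def close(): Unit = {
    if (res != null) {
      res.close()
      res = null
    }
  }
}
```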
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 91a6a2d039852..326fafb230a40 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -881,7 +881,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
buf
} : Seq[V]
- val res = self.context.runJob(self, process, Array(index), false)
+ val res = self.context.runJob(self, process, Array(index))
res(0)
case None =>
self.filter(_._1 == key).map(_._2).collect()
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index dc60d48927624..3bb9998e1db44 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -123,7 +123,9 @@ private[spark] class PipedRDD[T: ClassTag](
new Thread("stderr reader for " + command) {
override def run() {
for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
+ // scalastyle:off println
System.err.println(line)
+ // scalastyle:on println
}
}
}.start()
@@ -131,8 +133,10 @@ private[spark] class PipedRDD[T: ClassTag](
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + command) {
override def run() {
+ TaskContext.setTaskContext(context)
val out = new PrintWriter(proc.getOutputStream)
+ // scalastyle:off println
// input the pipe context firstly
if (printPipeContext != null) {
printPipeContext(out.println(_))
@@ -144,6 +148,7 @@ private[spark] class PipedRDD[T: ClassTag](
out.println(elem)
}
}
+ // scalastyle:on println
out.close()
}
}.start()
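The PipedRDD hunks above wrap the stderr-reader and stdin-writer threads that shuttle data to and from the child process. A self-contained sketch of that three-stream pattern (the helper is illustrative and omits error handling and task-context propagation):

```scala
import java.io.PrintWriter
import scala.io.Source

object Pipe {
  def run(command: Seq[String], input: Iterator[String]): Iterator[String] = {
    val proc = new ProcessBuilder(command: _*).start()

    // Drain stderr on its own thread so the child cannot block on a full pipe.
    new Thread("stderr reader") {
      override def run(): Unit =
        Source.fromInputStream(proc.getErrorStream).getLines().foreach(line => System.err.println(line))
    }.start()

    // Feed the child's stdin from the input iterator on another thread.
    new Thread("stdin writer") {
      override def run(): Unit = {
        val out = new PrintWriter(proc.getOutputStream)
        input.foreach(line => out.println(line))
        out.close()
      }
    }.start()

    // The caller consumes the child's stdout lazily.
    Source.fromInputStream(proc.getInputStream).getLines()
  }
}
```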
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9f7ebae3e9af3..6d61d227382d7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -897,7 +897,7 @@ abstract class RDD[T: ClassTag](
*/
def toLocalIterator: Iterator[T] = withScope {
def collectPartition(p: Int): Array[T] = {
- sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head
+ sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head
}
(0 until partitions.length).iterator.flatMap(i => collectPartition(i))
}
@@ -1082,7 +1082,9 @@ abstract class RDD[T: ClassTag](
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
// If creating an extra level doesn't help reduce
// the wall-clock time, we stop tree aggregation.
- while (numPartitions > scale + numPartitions / scale) {
+
+ // Don't trigger TreeAggregation when it doesn't save wall-clock time
+ while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
numPartitions /= scale
val curNumPartitions = numPartitions
partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
@@ -1273,7 +1275,7 @@ abstract class RDD[T: ClassTag](
val left = num - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
- val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true)
+ val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
res.foreach(buf ++= _.take(num - buf.size))
partsScanned += numPartsToTry
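The revised treeAggregate loop condition earlier in this file's diff only adds another tree-aggregation level while doing so actually shrinks the work left for the driver. A small illustrative helper (not part of the patch) that reproduces the schedule the loop would follow:

```scala
object TreeAggregationLevels {
  /** Partition counts at each level, mirroring the condition in RDD.treeAggregate above. */
  def levels(initialPartitions: Int, depth: Int): Seq[Int] = {
    val scale = math.max(math.ceil(math.pow(initialPartitions, 1.0 / depth)).toInt, 2)
    val counts = scala.collection.mutable.ArrayBuffer(initialPartitions)
    var numPartitions = initialPartitions
    // Stop once another level would no longer save wall-clock time.
    while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
      numPartitions /= scale
      counts += numPartitions
    }
    counts.toSeq
  }
}

// levels(1000, depth = 2) == Seq(1000, 31): scale is 32, and 31 <= 32 + ceil(31/32) stops the loop.
```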
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
similarity index 81%
rename from sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
rename to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
index 2bdc341021256..35e44cb59c1be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
@@ -15,28 +15,28 @@
* limitations under the License.
*/
-package org.apache.spark.sql.sources
+package org.apache.spark.rdd
import java.text.SimpleDateFormat
import java.util.Date
+import scala.reflect.ClassTag
+
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
-import org.apache.spark.broadcast.Broadcast
-
-import org.apache.spark.{Partition => SparkPartition, _}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
-import org.apache.spark.rdd.{RDD, HadoopRDD}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Partition => SparkPartition, _}
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{SerializableConfiguration, Utils}
-import scala.reflect.ClassTag
private[spark] class SqlNewHadoopPartition(
rddId: Int,
@@ -63,7 +63,7 @@ private[spark] class SqlNewHadoopPartition(
* changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be
* folded into core.
*/
-private[sql] class SqlNewHadoopRDD[K, V](
+private[spark] class SqlNewHadoopRDD[K, V](
@transient sc : SparkContext,
broadcastedConf: Broadcast[SerializableConfiguration],
@transient initDriverSideJobFuncOpt: Option[Job => Unit],
@@ -129,6 +129,12 @@ private[sql] class SqlNewHadoopRDD[K, V](
val inputMetrics = context.taskMetrics
.getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+ // Sets the thread local variable for the file's name
+ split.serializableHadoopSplit.value match {
+ case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
+ case _ => SqlNewHadoopRDD.unsetInputFileName()
+ }
+
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
@@ -148,7 +154,7 @@ private[sql] class SqlNewHadoopRDD[K, V](
configurable.setConf(conf)
case _ =>
}
- val reader = format.createRecordReader(
+ private var reader = format.createRecordReader(
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
@@ -161,6 +167,12 @@ private[sql] class SqlNewHadoopRDD[K, V](
override def hasNext: Boolean = {
if (!finished && !havePair) {
finished = !reader.nextKeyValue
+ if (finished) {
+ // Close and release the reader here; close() will also be called when the task
+ // completes, but for tasks that read from many files, it helps to release the
+ // resources early.
+ close()
+ }
havePair = !finished
}
!finished
@@ -179,18 +191,24 @@ private[sql] class SqlNewHadoopRDD[K, V](
private def close() {
try {
- reader.close()
- if (bytesReadCallback.isDefined) {
- inputMetrics.updateBytesRead()
- } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
- split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
- // If we can't get the bytes read from the FS stats, fall back to the split size,
- // which may be inaccurate.
- try {
- inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength)
- } catch {
- case e: java.io.IOException =>
- logWarning("Unable to get input size to set InputMetrics for task", e)
+ if (reader != null) {
+ reader.close()
+ reader = null
+
+ SqlNewHadoopRDD.unsetInputFileName()
+
+ if (bytesReadCallback.isDefined) {
+ inputMetrics.updateBytesRead()
+ } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
+ split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
+ // If we can't get the bytes read from the FS stats, fall back to the split size,
+ // which may be inaccurate.
+ try {
+ inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength)
+ } catch {
+ case e: java.io.IOException =>
+ logWarning("Unable to get input size to set InputMetrics for task", e)
+ }
}
}
} catch {
@@ -241,6 +259,21 @@ private[sql] class SqlNewHadoopRDD[K, V](
}
private[spark] object SqlNewHadoopRDD {
+
+ /**
+ * The thread variable for the name of the current file being read. This is used by
+ * the InputFileName function in Spark SQL.
+ */
+ private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
+ override protected def initialValue(): UTF8String = UTF8String.fromString("")
+ }
+
+ def getInputFileName(): UTF8String = inputFileName.get()
+
+ private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
+
+ private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
+
/**
* Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
* the given function rather than the index of the partition.
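The SqlNewHadoopRDD companion above carries the name of the file currently being read in a ThreadLocal so Spark SQL's InputFileName expression can read it from the task's thread. A stand-alone sketch of that pattern with plain Strings instead of UTF8String:

```scala
object CurrentInputFile {
  private val inputFileName: ThreadLocal[String] = new ThreadLocal[String] {
    // Every thread starts with an empty file name until a task sets one.
    override protected def initialValue(): String = ""
  }

  def get(): String = inputFileName.get()
  def set(file: String): Unit = inputFileName.set(file)
  def unset(): Unit = inputFileName.remove()
}
```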
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
index 523aaf2b860b5..e277ae28d588f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -50,8 +50,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L
prev.context.runJob(
prev,
Utils.getIteratorSize _,
- 0 until n - 1, // do not need to count the last partition
- allowLocal = false
+ 0 until n - 1 // do not need to count the last partition
).scanLeft(0L)(_ + _)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 1709bdf560b6f..29debe8081308 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -39,8 +39,7 @@ private[spark] object RpcEnv {
val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
val rpcEnvName = conf.get("spark.rpc", "akka")
val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
- Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader).
- newInstance().asInstanceOf[RpcEnvFactory]
+ Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]
}
def create(
@@ -140,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
* creating it manually because different [[RpcEnv]] may have different formats.
*/
def uriOf(systemName: String, address: RpcAddress, endpointName: String): String
+
+ /**
+ * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object
+ * that contains [[RpcEndpointRef]]s, the deserialization code should be wrapped by this method.
+ */
+ def deserialize[T](deserializationAction: () => T): T
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index f2d87f68341af..fc17542abf81d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add
import akka.event.Logging.Error
import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
-import com.google.common.util.concurrent.MoreExecutors
+import akka.serialization.JavaSerializer
import org.apache.spark.{SparkException, Logging, SparkConf}
import org.apache.spark.rpc._
@@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] (
}
override def toString: String = s"${getClass.getSimpleName}($actorSystem)"
+
+ override def deserialize[T](deserializationAction: () => T): T = {
+ JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) {
+ deserializationAction()
+ }
+ }
}
private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
@@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef(
override def toString: String = s"${getClass.getSimpleName}($actorRef)"
+ final override def equals(that: Any): Boolean = that match {
+ case other: AkkaRpcEndpointRef => actorRef == other.actorRef
+ case _ => false
+ }
+
+ final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode()
}
/**
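The equals/hashCode pair added to AkkaRpcEndpointRef follows the usual delegate-to-the-wrapped-reference contract. A generic sketch (EndpointRef here is a stand-in class, not Spark's):

```scala
class EndpointRef(val underlying: AnyRef) {
  // Two wrappers are equal exactly when they wrap the same underlying reference,
  // and equal wrappers hash to the same value.
  final override def equals(that: Any): Boolean = that match {
    case other: EndpointRef => underlying == other.underlying
    case _ => false
  }

  final override def hashCode(): Int = if (underlying == null) 0 else underlying.hashCode()
}
```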
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 6841fa835747f..c4fa277c21254 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -22,7 +22,8 @@ import java.util.Properties
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
+import scala.collection.Map
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
import scala.concurrent.duration._
import scala.language.existentials
import scala.language.postfixOps
@@ -37,7 +38,6 @@ import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.storage._
-import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -127,10 +127,6 @@ class DAGScheduler(
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
-
- /** If enabled, we may run certain actions like take() and first() locally. */
- private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
-
/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
@@ -514,7 +510,6 @@ class DAGScheduler(
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: CallSite,
- allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
properties: Properties): JobWaiter[U] = {
// Check to make sure we are not launching a task on a partition that does not exist.
@@ -534,7 +529,7 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
eventProcessLoop.post(JobSubmitted(
- jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter,
+ jobId, rdd, func2, partitions.toArray, callSite, waiter,
SerializationUtils.clone(properties)))
waiter
}
@@ -544,11 +539,10 @@ class DAGScheduler(
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: CallSite,
- allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
properties: Properties): Unit = {
val start = System.nanoTime
- val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
+ val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded =>
logInfo("Job %d finished: %s, took %f s".format
@@ -556,6 +550,9 @@ class DAGScheduler(
case JobFailed(exception: Exception) =>
logInfo("Job %d failed: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
+ // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
+ val callerStackTrace = Thread.currentThread().getStackTrace.tail
+ exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
throw exception
}
}
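The SPARK-8644 lines above splice the caller's stack trace onto the exception thrown from the scheduler, so the user sees both where the job failed and which action triggered it. A minimal sketch of that trick (the object name is illustrative):

```scala
object CallerTrace {
  /** Rethrows `exception` with the current thread's stack trace appended to its own. */
  def rethrowWithCallerTrace(exception: Exception): Nothing = {
    // Drop the first frame (Thread.getStackTrace itself) before appending.
    val callerStackTrace = Thread.currentThread().getStackTrace.tail
    exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
    throw exception
  }
}
```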
@@ -572,8 +569,7 @@ class DAGScheduler(
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
eventProcessLoop.post(JobSubmitted(
- jobId, rdd, func2, partitions, allowLocal = false, callSite, listener,
- SerializationUtils.clone(properties)))
+ jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties)))
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -650,73 +646,6 @@ class DAGScheduler(
}
}
- /**
- * Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
- * We run the operation in a separate thread just in case it takes a bunch of time, so that we
- * don't block the DAGScheduler event loop or other concurrent jobs.
- */
- protected def runLocally(job: ActiveJob) {
- logInfo("Computing the requested partition locally")
- new Thread("Local computation of job " + job.jobId) {
- override def run() {
- runLocallyWithinThread(job)
- }
- }.start()
- }
-
- // Broken out for easier testing in DAGSchedulerSuite.
- protected def runLocallyWithinThread(job: ActiveJob) {
- var jobResult: JobResult = JobSucceeded
- try {
- val rdd = job.finalStage.rdd
- val split = rdd.partitions(job.partitions(0))
- val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
- val taskContext =
- new TaskContextImpl(
- job.finalStage.id,
- job.partitions(0),
- taskAttemptId = 0,
- attemptNumber = 0,
- taskMemoryManager = taskMemoryManager,
- runningLocally = true)
- TaskContext.setTaskContext(taskContext)
- try {
- val result = job.func(taskContext, rdd.iterator(split, taskContext))
- job.listener.taskSucceeded(0, result)
- } finally {
- taskContext.markTaskCompleted()
- TaskContext.unset()
- // Note: this memory freeing logic is duplicated in Executor.run(); when changing this,
- // make sure to update both copies.
- val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
- if (freedMemory > 0) {
- if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
- throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
- } else {
- logError(s"Managed memory leak detected; size = $freedMemory bytes")
- }
- }
- }
- } catch {
- case e: Exception =>
- val exception = new SparkDriverExecutionException(e)
- jobResult = JobFailed(exception)
- job.listener.jobFailed(exception)
- case oom: OutOfMemoryError =>
- val exception = new SparkException("Local job aborted due to out of memory error", oom)
- jobResult = JobFailed(exception)
- job.listener.jobFailed(exception)
- } finally {
- val s = job.finalStage
- // clean up data structures that were populated for a local job,
- // but that won't get cleaned up via the normal paths through
- // completion events or stage abort
- stageIdToStage -= s.id
- jobIdToStageIds -= job.jobId
- listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult))
- }
- }
-
/** Finds the earliest-created active job that needs the stage */
// TODO: Probably should actually find among the active jobs that need this
// stage the one with the highest priority (highest-priority pool, earliest created).
@@ -779,7 +708,6 @@ class DAGScheduler(
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
- allowLocal: Boolean,
callSite: CallSite,
listener: JobListener,
properties: Properties) {
@@ -797,29 +725,20 @@ class DAGScheduler(
if (finalStage != null) {
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
- logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format(
- job.jobId, callSite.shortForm, partitions.length, allowLocal))
+ logInfo("Got job %s (%s) with %d output partitions".format(
+ job.jobId, callSite.shortForm, partitions.length))
logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
- val shouldRunLocally =
- localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
val jobSubmissionTime = clock.getTimeMillis()
- if (shouldRunLocally) {
- // Compute very short actions like first() or take() with no parent stages locally.
- listenerBus.post(
- SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties))
- runLocally(job)
- } else {
- jobIdToActiveJob(jobId) = job
- activeJobs += job
- finalStage.resultOfJob = Some(job)
- val stageIds = jobIdToStageIds(jobId).toArray
- val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
- listenerBus.post(
- SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
- submitStage(finalStage)
- }
+ jobIdToActiveJob(jobId) = job
+ activeJobs += job
+ finalStage.resultOfJob = Some(job)
+ val stageIds = jobIdToStageIds(jobId).toArray
+ val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
+ listenerBus.post(
+ SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
+ submitStage(finalStage)
}
submitWaitingStages()
}
@@ -853,7 +772,6 @@ class DAGScheduler(
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
-
// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = {
stage match {
@@ -872,8 +790,28 @@ class DAGScheduler(
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
- stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
outputCommitCoordinator.stageStart(stage.id)
+ val taskIdToLocations = try {
+ stage match {
+ case s: ShuffleMapStage =>
+ partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
+ case s: ResultStage =>
+ val job = s.resultOfJob.get
+ partitionsToCompute.map { id =>
+ val p = job.partitions(id)
+ (id, getPreferredLocs(stage.rdd, p))
+ }.toMap
+ }
+ } catch {
+ case NonFatal(e) =>
+ stage.makeNewStageAttempt(partitionsToCompute.size)
+ listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
+ abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+ runningStages -= stage
+ return
+ }
+
+ stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
@@ -912,9 +850,9 @@ class DAGScheduler(
stage match {
case stage: ShuffleMapStage =>
partitionsToCompute.map { id =>
- val locs = getPreferredLocs(stage.rdd, id)
+ val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
- new ShuffleMapTask(stage.id, taskBinary, part, locs)
+ new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
}
case stage: ResultStage =>
@@ -922,8 +860,8 @@ class DAGScheduler(
partitionsToCompute.map { id =>
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
- val locs = getPreferredLocs(stage.rdd, p)
- new ResultTask(stage.id, taskBinary, part, locs, id)
+ val locs = taskIdToLocations(id)
+ new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
}
}
} catch {
@@ -937,8 +875,8 @@ class DAGScheduler(
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
stage.pendingTasks ++= tasks
logDebug("New pending tasks: " + stage.pendingTasks)
- taskScheduler.submitTasks(
- new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.firstJobId, properties))
+ taskScheduler.submitTasks(new TaskSet(
+ tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties))
stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should mark
@@ -978,11 +916,9 @@ class DAGScheduler(
// To avoid UI cruft, ignore cases where value wasn't updated
if (acc.name.isDefined && partialValue != acc.zero) {
val name = acc.name.get
- val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
- val stringValue = Accumulators.stringifyValue(acc.value)
- stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
+ stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}")
event.taskInfo.accumulables +=
- AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
+ AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}")
}
}
} catch {
@@ -1009,7 +945,7 @@ class DAGScheduler(
// The success case is dealt with separately below, since we need to compute accumulator
// updates before posting.
if (event.reason != Success) {
- val attemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1)
+ val attemptId = task.stageAttemptId
listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason,
event.taskInfo, event.taskMetrics))
}
@@ -1065,10 +1001,11 @@ class DAGScheduler(
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
- logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+ logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
} else {
shuffleStage.addOutputLoc(smt.partitionId, status)
}
+
if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
@@ -1128,38 +1065,48 @@ class DAGScheduler(
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
- // It is likely that we receive multiple FetchFailed for a single stage (because we have
- // multiple tasks running concurrently on different executors). In that case, it is possible
- // the fetch failure has already been handled by the scheduler.
- if (runningStages.contains(failedStage)) {
- logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
- s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some(failureMessage))
- }
+ if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
+ logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+ s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
+ s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
+ } else {
- if (disallowStageRetryForTest) {
- abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
- } else if (failedStages.isEmpty) {
- // Don't schedule an event to resubmit failed stages if failed isn't empty, because
- // in that case the event will already have been scheduled.
- // TODO: Cancel running tasks in the stage
- logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
- s"$failedStage (${failedStage.name}) due to fetch failure")
- messageScheduler.schedule(new Runnable {
- override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
- }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
- }
- failedStages += failedStage
- failedStages += mapStage
- // Mark the map whose fetch failed as broken in the map stage
- if (mapId != -1) {
- mapStage.removeOutputLoc(mapId, bmAddress)
- mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
- }
+ // It is likely that we receive multiple FetchFailed for a single stage (because we have
+ // multiple tasks running concurrently on different executors). In that case, it is
+ // possible the fetch failure has already been handled by the scheduler.
+ if (runningStages.contains(failedStage)) {
+ logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+ s"due to a fetch failure from $mapStage (${mapStage.name})")
+ markStageAsFinished(failedStage, Some(failureMessage))
+ } else {
+ logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
+ s"longer running")
+ }
- // TODO: mark the executor as failed only if there were lots of fetch failures on it
- if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+ if (disallowStageRetryForTest) {
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ } else if (failedStages.isEmpty) {
+ // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+ // in that case the event will already have been scheduled.
+ // TODO: Cancel running tasks in the stage
+ logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+ s"$failedStage (${failedStage.name}) due to fetch failure")
+ messageScheduler.schedule(new Runnable {
+ override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+ }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+ }
+ failedStages += failedStage
+ failedStages += mapStage
+ // Mark the map whose fetch failed as broken in the map stage
+ if (mapId != -1) {
+ mapStage.removeOutputLoc(mapId, bmAddress)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ }
+
+ // TODO: mark the executor as failed only if there were lots of fetch failures on it
+ if (bmAddress != null) {
+ handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+ }
}
case commitDenied: TaskCommitDenied =>
@@ -1471,9 +1418,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
}
private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
- case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
- dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
- listener, properties)
+ case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
+ dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)
case StageCancelled(stageId) =>
dagScheduler.handleStageCancellation(stageId)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 2b6f7e4205c32..a213d419cf033 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import java.util.Properties
-import scala.collection.mutable.Map
+import scala.collection.Map
import scala.language.existentials
import org.apache.spark._
@@ -40,7 +40,6 @@ private[scheduler] case class JobSubmitted(
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
- allowLocal: Boolean,
callSite: CallSite,
listener: JobListener,
properties: Properties = null)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 529a5b2bf1a0d..5a06ef02f5c57 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -140,7 +140,9 @@ private[spark] class EventLoggingListener(
/** Log the event as JSON. */
private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) {
val eventJson = JsonProtocol.sparkEventToJson(event)
+ // scalastyle:off println
writer.foreach(_.println(compact(render(eventJson))))
+ // scalastyle:on println
if (flushLogger) {
writer.foreach(_.flush())
hadoopDataStream.foreach(hadoopFlushMethod.invoke(_))
@@ -197,6 +199,9 @@ private[spark] class EventLoggingListener(
logEvent(event, flushLogger = true)
}
+ // No-op because logging every update would be overkill
+ override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {}
+
// No-op because logging every update would be overkill
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index e55b76c36cc5f..f96eb8ca0ae00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -125,7 +125,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
val date = new Date(System.currentTimeMillis())
writeInfo = dateFormat.get.format(date) + ": " + info
}
+ // scalastyle:off println
jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo))
+ // scalastyle:on println
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index c9a124113961f..9c2606e278c54 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
*/
private[spark] class ResultTask[T, U](
stageId: Int,
+ stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
val outputId: Int)
- extends Task[U](stageId, partition.index) with Serializable {
+ extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index bd3dd23dfe1ac..14c8c00961487 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
*/
private[spark] class ShuffleMapTask(
stageId: Int,
+ stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId, partition.index) with Logging {
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, null, new Partition { override def index: Int = 0 }, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 9620915f495ab..896f1743332f1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -26,7 +26,7 @@ import org.apache.spark.{Logging, TaskEndReason}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.cluster.ExecutorInfo
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo}
import org.apache.spark.util.{Distribution, Utils}
@DeveloperApi
@@ -98,6 +98,9 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn
case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String)
extends SparkListenerEvent
+@DeveloperApi
+case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent
+
/**
* Periodic updates from executors.
* @param execId executor id
@@ -215,6 +218,11 @@ trait SparkListener {
* Called when the driver removes an executor.
*/
def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { }
+
+ /**
+ * Called when the driver receives a block update info.
+ */
+ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { }
}
/**
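The new SparkListenerBlockUpdated event and onBlockUpdated hook above are consumed by subclassing SparkListener. A hedged usage sketch (the counter class is illustrative; registration via SparkContext.addSparkListener is assumed to be available as a DeveloperApi):

```scala
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}

class BlockUpdateCounter extends SparkListener {
  val updates = new AtomicLong(0L)

  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
    // blockUpdated.blockUpdatedInfo describes the block that changed.
    updates.incrementAndGet()
  }
}

// Registration (sc is an existing SparkContext): sc.addSparkListener(new BlockUpdateCounter)
```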
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 61e69ecc08387..04afde33f5aad 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -58,6 +58,8 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi
listener.onExecutorAdded(executorAdded)
case executorRemoved: SparkListenerExecutorRemoved =>
listener.onExecutorRemoved(executorRemoved)
+ case blockUpdated: SparkListenerBlockUpdated =>
+ listener.onBlockUpdated(blockUpdated)
case logStart: SparkListenerLogStart => // ignore event log metadata
}
}
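// Illustrative sketch (not part of this patch): a listener that consumes the new
// SparkListenerBlockUpdated event. Registration via SparkContext.addSparkListener is assumed.
class BlockUpdateLogger extends SparkListener {
  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
    // BlockUpdatedInfo describes which block changed and how it is now stored.
    println(s"Block updated: ${blockUpdated.blockUpdatedInfo}")
  }
}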
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index c59d6e4f5bc04..40a333a3e06b2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -62,22 +62,31 @@ private[spark] abstract class Stage(
var pendingTasks = new HashSet[Task[_]]
+ /** The ID to use for the next new attempt for this stage. */
private var nextAttemptId: Int = 0
val name = callSite.shortForm
val details = callSite.longForm
- /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */
- var latestInfo: StageInfo = StageInfo.fromStage(this)
+ /**
+ * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized
+ * here, before any attempts have actually been created, because the DAGScheduler uses this
+ * StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts
+ * have been created).
+ */
+ private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId)
- /** Return a new attempt id, starting with 0. */
- def newAttemptId(): Int = {
- val id = nextAttemptId
+ /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
+ def makeNewStageAttempt(
+ numPartitionsToCompute: Int,
+ taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = {
+ _latestInfo = StageInfo.fromStage(
+ this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences)
nextAttemptId += 1
- id
}
- def attemptId: Int = nextAttemptId
+ /** Returns the StageInfo for the most recent attempt for this stage. */
+ def latestInfo: StageInfo = _latestInfo
override final def hashCode(): Int = id
override final def equals(other: Any): Boolean = other match {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index e439d2a7e1229..24796c14300b1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -34,7 +34,8 @@ class StageInfo(
val numTasks: Int,
val rddInfos: Seq[RDDInfo],
val parentIds: Seq[Int],
- val details: String) {
+ val details: String,
+ private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) {
/** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
var submissionTime: Option[Long] = None
/** Time when all tasks in the stage completed or when the stage was cancelled. */
@@ -70,16 +71,22 @@ private[spark] object StageInfo {
* shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a
* sequence of narrow dependencies should also be associated with this Stage.
*/
- def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = {
+ def fromStage(
+ stage: Stage,
+ attemptId: Int,
+ numTasks: Option[Int] = None,
+ taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty
+ ): StageInfo = {
val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos
new StageInfo(
stage.id,
- stage.attemptId,
+ attemptId,
stage.name,
numTasks.getOrElse(stage.numTasks),
rddInfos,
stage.parents.map(_.id),
- stage.details)
+ stage.details,
+ taskLocalityPreferences)
}
}
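// Sketch of the new attempt lifecycle (illustrative; assumes a concrete Stage instance `stage`):
stage.makeNewStageAttempt(numPartitionsToCompute = 8)   // creates the StageInfo for attempt 0
val firstAttempt = stage.latestInfo.attemptId           // => 0
stage.makeNewStageAttempt(numPartitionsToCompute = 8)   // e.g. after a fetch failure
val secondAttempt = stage.latestInfo.attemptId          // => 1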
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 15101c64f0503..1978305cfefbd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,7 +22,8 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
-import org.apache.spark.{TaskContextImpl, TaskContext}
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -43,34 +44,60 @@ import org.apache.spark.util.Utils
* @param stageId id of the stage this task belongs to
* @param partitionId index of the number in the RDD
*/
-private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+private[spark] abstract class Task[T](
+ val stageId: Int,
+ val stageAttemptId: Int,
+ var partitionId: Int) extends Serializable {
+
+ /**
+ * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
+ * local value.
+ */
+ type AccumulatorUpdates = Map[Long, Any]
/**
* Called by [[Executor]] to run this task.
*
* @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
* @param attemptNumber how many times this task has been attempted (0 for the first attempt)
- * @return the result of the task
+ * @return the result of the task along with updates of Accumulators.
*/
- final def run(taskAttemptId: Long, attemptNumber: Int): T = {
+ final def run(
+ taskAttemptId: Long,
+ attemptNumber: Int,
+ metricsSystem: MetricsSystem)
+ : (T, AccumulatorUpdates) = {
context = new TaskContextImpl(
stageId = stageId,
partitionId = partitionId,
taskAttemptId = taskAttemptId,
attemptNumber = attemptNumber,
taskMemoryManager = taskMemoryManager,
+ metricsSystem = metricsSystem,
runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
+ context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
}
try {
- runTask(context)
+ (runTask(context), context.collectAccumulators())
} finally {
context.markTaskCompleted()
- TaskContext.unset()
+ try {
+ Utils.tryLogNonFatalError {
+ // Release memory used by this thread for shuffles
+ SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
+ }
+ Utils.tryLogNonFatalError {
+ // Release memory used by this thread for unrolling blocks
+ SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
+ }
+ } finally {
+ TaskContext.unset()
+ }
}
}
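// Hypothetical executor-side call site (names and values are illustrative; the real call
// lives in the executor's task runner): run() now returns the task result together with
// the accumulator updates collected from the TaskContext.
val (result, accumUpdates) = task.run(
  taskAttemptId = 1234L,
  attemptNumber = 0,
  metricsSystem = SparkEnv.get.metricsSystem)
// accumUpdates: Map[Long, Any] maps each accumulator id to its latest task-local value.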
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 8b2a742b96988..b82c7f3fa54f8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -20,7 +20,8 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer
-import scala.collection.mutable.Map
+import scala.collection.Map
+import scala.collection.mutable
import org.apache.spark.SparkEnv
import org.apache.spark.executor.TaskMetrics
@@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
if (numUpdates == 0) {
accumUpdates = null
} else {
- accumUpdates = Map()
+ val _accumUpdates = mutable.Map[Long, Any]()
for (i <- 0 until numUpdates) {
- accumUpdates(in.readLong()) = in.readObject()
+ _accumUpdates(in.readLong()) = in.readObject()
}
+ accumUpdates = _accumUpdates
}
metrics = in.readObject().asInstanceOf[TaskMetrics]
valueObjectDeserialized = false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index ed3dde0fc3055..1705e7f962de2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
- val activeTaskSets = new HashMap[String, TaskSetManager]
+ private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
- val taskIdToTaskSetId = new HashMap[Long, String]
+ private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]
@volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl(
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
- activeTaskSets(taskSet.id) = manager
+ val stage = taskSet.stageId
+ val stageTaskSets =
+ taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+ stageTaskSets(taskSet.stageAttemptId) = manager
+ val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
+ ts.taskSet != taskSet && !ts.isZombie
+ }
+ if (conflictingTaskSet) {
+ throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
+ s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+ }
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl(
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
- activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
- // There are two possible cases here:
- // 1. The task set manager has been created and some tasks have been scheduled.
- // In this case, send a kill signal to the executors to kill the task and then abort
- // the stage.
- // 2. The task set manager has been created but no tasks has been scheduled. In this case,
- // simply abort the stage.
- tsm.runningTasksSet.foreach { tid =>
- val execId = taskIdToExecutorId(tid)
- backend.killTask(tid, execId, interruptThread)
+ taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+ attempts.foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply abort the stage.
+ tsm.runningTasksSet.foreach { tid =>
+ val execId = taskIdToExecutorId(tid)
+ backend.killTask(tid, execId, interruptThread)
+ }
+ tsm.abort("Stage %s cancelled".format(stageId))
+ logInfo("Stage %d was cancelled".format(stageId))
}
- tsm.abort("Stage %s cancelled".format(stageId))
- logInfo("Stage %d was cancelled".format(stageId))
}
}
@@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl(
* cleaned up.
*/
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
- activeTaskSets -= manager.taskSet.id
+ taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+ taskSetsForStage -= manager.taskSet.stageAttemptId
+ if (taskSetsForStage.isEmpty) {
+ taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
+ }
+ }
manager.parent.removeSchedulable(manager)
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
.format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
- taskIdToTaskSetId(tid) = taskSet.taskSet.id
+ taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl(
failedExecutor = Some(execId)
}
}
- taskIdToTaskSetId.get(tid) match {
- case Some(taskSetId) =>
+ taskIdToTaskSetManager.get(tid) match {
+ case Some(taskSet) =>
if (TaskState.isFinished(state)) {
- taskIdToTaskSetId.remove(tid)
+ taskIdToTaskSetManager.remove(tid)
taskIdToExecutorId.remove(tid)
}
- activeTaskSets.get(taskSetId).foreach { taskSet =>
- if (state == TaskState.FINISHED) {
- taskSet.removeRunningTask(tid)
- taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
- } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
- taskSet.removeRunningTask(tid)
- taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
- }
+ if (state == TaskState.FINISHED) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+ } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
case None =>
logError(
("Ignoring update with state %s for TID %s because its task set is gone (this is " +
- "likely the result of receiving duplicate task finished status updates)")
- .format(state, tid))
+ "likely the result of receiving duplicate task finished status updates)")
+ .format(state, tid))
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
@@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl(
val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
taskMetrics.flatMap { case (id, metrics) =>
- taskIdToTaskSetId.get(id)
- .flatMap(activeTaskSets.get)
- .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+ taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+ (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+ }
}
}
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl(
def error(message: String) {
synchronized {
- if (activeTaskSets.nonEmpty) {
+ if (taskSetsByStageIdAndAttempt.nonEmpty) {
// Have each task set throw a SparkException with the error
- for ((taskSetId, manager) <- activeTaskSets) {
+ for {
+ attempts <- taskSetsByStageIdAndAttempt.values
+ manager <- attempts.values
+ } {
try {
manager.abort(message)
} catch {
@@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl(
override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
+ private[scheduler] def taskSetManagerForAttempt(
+ stageId: Int,
+ stageAttemptId: Int): Option[TaskSetManager] = {
+ for {
+ attempts <- taskSetsByStageIdAndAttempt.get(stageId)
+ manager <- attempts.get(stageAttemptId)
+ } yield {
+ manager
+ }
+ }
+
}
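// Conceptual shape of the new bookkeeping (illustrative only):
//   taskSetsByStageIdAndAttempt: stageId -> (stageAttemptId -> TaskSetManager)
//   taskIdToTaskSetManager:      taskId  -> TaskSetManager
// A specific attempt can now be looked up directly (the helper is private[scheduler],
// so this call is only valid from within the scheduler package):
val manager: Option[TaskSetManager] =
  scheduler.taskSetManagerForAttempt(stageId = 3, stageAttemptId = 1)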
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index c3ad325156f53..be8526ba9b94f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -26,10 +26,10 @@ import java.util.Properties
private[spark] class TaskSet(
val tasks: Array[Task[_]],
val stageId: Int,
- val attempt: Int,
+ val stageAttemptId: Int,
val priority: Int,
val properties: Properties) {
- val id: String = stageId + "." + attempt
+ val id: String = stageId + "." + stageAttemptId
override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 4be1eda2e9291..06f5438433b6e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -86,7 +86,11 @@ private[spark] object CoarseGrainedClusterMessages {
// Request executors by specifying the new total number of executors desired
// This includes executors already pending or running
- case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage
+ case class RequestExecutors(
+ requestedTotal: Int,
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: Map[String, Int])
+ extends CoarseGrainedClusterMessage
case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 7c7f70d8a193b..bd89160af4ffa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -66,6 +66,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Executors we have requested the cluster manager to kill that have not died yet
private val executorsPendingToRemove = new HashSet[String]
+ // A map from each host to the number of tasks that might run on it
+ protected var hostToLocalTaskCount: Map[String, Int] = Map.empty
+
+ // The number of pending tasks that have locality preferences
+ protected var localityAwareTasks = 0
+
class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
extends ThreadSafeRpcEndpoint with Logging {
@@ -169,9 +175,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Make fake resource offers on all executors
private def makeOffers() {
- launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) =>
+ // Filter out executors under killing
+ val activeExecutors = executorDataMap.filterKeys(!executorsPendingToRemove.contains(_))
+ val workOffers = activeExecutors.map { case (id, executorData) =>
new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
- }.toSeq))
+ }.toSeq
+ launchTasks(scheduler.resourceOffers(workOffers))
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
@@ -181,9 +190,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Make fake resource offers on just one executor
private def makeOffers(executorId: String) {
- val executorData = executorDataMap(executorId)
- launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))))
+ // Filter out executors under killing
+ if (!executorsPendingToRemove.contains(executorId)) {
+ val executorData = executorDataMap(executorId)
+ val workOffers = Seq(
+ new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
+ launchTasks(scheduler.resourceOffers(workOffers))
+ }
}
// Launch tasks returned by a set of resource offers
@@ -191,15 +204,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
- val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
- scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+ scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
"spark.akka.frameSize or using broadcast variables for large values."
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
AkkaUtils.reservedSizeBytes)
- taskSet.abort(msg)
+ taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
@@ -229,7 +241,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
scheduler.executorLost(executorId, SlaveLost(reason))
listenerBus.post(
SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason))
- case None => logError(s"Asked to remove non-existent executor $executorId")
+ case None => logInfo(s"Asked to remove non-existent executor $executorId")
}
}
@@ -333,6 +345,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
logDebug(s"Number of pending executors is now $numPendingExecutors")
+
numPendingExecutors += numAdditionalExecutors
// Account for executors pending to be added or removed
val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
@@ -340,16 +353,33 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
/**
- * Express a preference to the cluster manager for a given total number of executors. This can
- * result in canceling pending requests or filing additional requests.
- * @return whether the request is acknowledged.
+ * Update the cluster manager on our scheduling needs. Three bits of information are included
+ * to help it make decisions.
+ * @param numExecutors The total number of executors we'd like to have. The cluster manager
+ * shouldn't kill any running executor to reach this number, but,
+ * if all existing executors were to die, this is the number of executors
+ * we'd want to be allocated.
+ * @param localityAwareTasks The number of tasks in all active stages that have locality
+ * preferences. This includes running, pending, and completed tasks.
+ * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages
+ * that would like to run on that host.
+ * This includes running, pending, and completed tasks.
+ * @return whether the request is acknowledged by the cluster manager.
*/
- final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized {
+ final override def requestTotalExecutors(
+ numExecutors: Int,
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: Map[String, Int]
+ ): Boolean = synchronized {
if (numExecutors < 0) {
throw new IllegalArgumentException(
"Attempted to request a negative number of executor(s) " +
s"$numExecutors from the cluster manager. Please specify a positive number!")
}
+
+ this.localityAwareTasks = localityAwareTasks
+ this.hostToLocalTaskCount = hostToLocalTaskCount
+
numPendingExecutors =
math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0)
doRequestTotalExecutors(numExecutors)
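// Illustrative call with the two new locality hints (values are made up):
backend.requestTotalExecutors(
  numExecutors = 10,
  localityAwareTasks = 20,
  hostToLocalTaskCount = Map("host1" -> 15, "host2" -> 5))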
@@ -371,26 +401,36 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
/**
* Request that the cluster manager kill the specified executors.
- * Return whether the kill request is acknowledged.
+ * @return whether the kill request is acknowledged.
*/
final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized {
+ killExecutors(executorIds, replace = false)
+ }
+
+ /**
+ * Request that the cluster manager kill the specified executors.
+ *
+ * @param executorIds identifiers of executors to kill
+ * @param replace whether to replace the killed executors with new ones
+ * @return whether the kill request is acknowledged.
+ */
+ final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized {
logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
- val filteredExecutorIds = new ArrayBuffer[String]
- executorIds.foreach { id =>
- if (executorDataMap.contains(id)) {
- filteredExecutorIds += id
- } else {
- logWarning(s"Executor to kill $id does not exist!")
- }
+ val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains)
+ unknownExecutors.foreach { id =>
+ logWarning(s"Executor to kill $id does not exist!")
+ }
+
+ // If we do not wish to replace the executors we kill, sync the target number of executors
+ // with the cluster manager to avoid allocating new ones. When computing the new target,
+ // take into account executors that are pending to be added or removed.
+ if (!replace) {
+ doRequestTotalExecutors(numExistingExecutors + numPendingExecutors
+ - executorsPendingToRemove.size - knownExecutors.size)
}
- // Killing executors means effectively that we want less executors than before, so also update
- // the target number of executors to avoid having the backend allocate new ones.
- val newTotal = (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
- - filteredExecutorIds.size)
- doRequestTotalExecutors(newTotal)
- executorsPendingToRemove ++= filteredExecutorIds
- doKillExecutors(filteredExecutorIds)
+ executorsPendingToRemove ++= knownExecutors
+ doKillExecutors(knownExecutors)
}
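// Sketch of the two kill modes (executor ids are illustrative):
backend.killExecutors(Seq("7"))                   // replace = false: also lowers the target total
backend.killExecutors(Seq("42"), replace = true)  // target total kept, so a replacement is allocated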
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index bc67abb5df446..044f6288fabdd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -53,7 +53,8 @@ private[spark] abstract class YarnSchedulerBackend(
* This includes executors already pending or running.
*/
override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
- yarnSchedulerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal))
+ yarnSchedulerEndpoint.askWithRetry[Boolean](
+ RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount))
}
/**
@@ -108,6 +109,8 @@ private[spark] abstract class YarnSchedulerBackend(
case AddWebUIFilter(filterName, filterParams, proxyBase) =>
addWebUIFilter(filterName, filterParams, proxyBase)
+ case RemoveExecutor(executorId, reason) =>
+ removeExecutor(executorId, reason)
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 6b8edca5aa485..b7fde0d9b3265 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -18,11 +18,13 @@
package org.apache.spark.scheduler.cluster.mesos
import java.io.File
+import java.util.concurrent.locks.ReentrantLock
import java.util.{Collections, List => JList}
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, HashSet}
+import com.google.common.collect.HashBiMap
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
import org.apache.mesos.{Scheduler => MScheduler, _}
import org.apache.spark.rpc.RpcAddress
@@ -60,12 +62,34 @@ private[spark] class CoarseMesosSchedulerBackend(
val slaveIdsWithExecutors = new HashSet[String]
- val taskIdToSlaveId = new HashMap[Int, String]
- val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
+ val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String]
+ // How many times tasks on each slave failed
+ val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int]
+ /**
+ * The total number of executors we aim to have. Undefined when not using dynamic allocation
+ * and before the ExecutorAllocationManager calls [[doRequestTotalExecutors]].
+ */
+ private var executorLimitOption: Option[Int] = None
+
+ /**
+ * Return the current executor limit, which may be [[Int.MaxValue]]
+ * before properly initialized.
+ */
+ private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue)
+
+ private val pendingRemovedSlaveIds = new HashSet[String]
+
+ // private lock object protecting mutable state above. Using the intrinsic lock
+ // may lead to deadlocks since the superclass might also try to lock it
+ private val stateLock = new ReentrantLock
val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0)
+ // Offer constraints
+ private val slaveOfferConstraints =
+ parseConstraintString(sc.conf.get("spark.mesos.constraints", ""))
+
var nextMesosTaskId = 0
@volatile var appId: String = _
@@ -78,11 +102,12 @@ private[spark] class CoarseMesosSchedulerBackend(
override def start() {
super.start()
- val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
- startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo)
+ val driver = createSchedulerDriver(
+ master, CoarseMesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf)
+ startScheduler(driver)
}
- def createCommand(offer: Offer, numCores: Int): CommandInfo = {
+ def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = {
val executorSparkHome = conf.getOption("spark.mesos.executor.home")
.orElse(sc.getSparkHome())
.getOrElse {
@@ -116,10 +141,6 @@ private[spark] class CoarseMesosSchedulerBackend(
}
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
- val driverUrl = sc.env.rpcEnv.uriOf(
- SparkEnv.driverActorSystemName,
- RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt),
- CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
val uri = conf.getOption("spark.executor.uri")
.orElse(Option(System.getenv("SPARK_EXECUTOR_URI")))
@@ -129,7 +150,7 @@ private[spark] class CoarseMesosSchedulerBackend(
command.setValue(
"%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend"
.format(prefixEnv, runScript) +
- s" --driver-url $driverUrl" +
+ s" --driver-url $driverURL" +
s" --executor-id ${offer.getSlaveId.getValue}" +
s" --hostname ${offer.getHostname}" +
s" --cores $numCores" +
@@ -138,11 +159,12 @@ private[spark] class CoarseMesosSchedulerBackend(
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.get.split('/').last.split('.').head
+ val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString)
command.setValue(
s"cd $basename*; $prefixEnv " +
"./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" +
- s" --driver-url $driverUrl" +
- s" --executor-id ${offer.getSlaveId.getValue}" +
+ s" --driver-url $driverURL" +
+ s" --executor-id $executorId" +
s" --hostname ${offer.getHostname}" +
s" --cores $numCores" +
s" --app-id $appId")
@@ -151,6 +173,17 @@ private[spark] class CoarseMesosSchedulerBackend(
command.build()
}
+ protected def driverURL: String = {
+ if (conf.contains("spark.testing")) {
+ "driverURL"
+ } else {
+ sc.env.rpcEnv.uriOf(
+ SparkEnv.driverActorSystemName,
+ RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt),
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
+ }
+ }
+
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
@@ -168,15 +201,19 @@ private[spark] class CoarseMesosSchedulerBackend(
* unless we've already launched more than we wanted to.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
- synchronized {
+ stateLock.synchronized {
val filters = Filters.newBuilder().setRefuseSeconds(5).build()
-
for (offer <- offers) {
- val slaveId = offer.getSlaveId.toString
+ val offerAttributes = toAttributeMap(offer.getAttributesList)
+ val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes)
+ val slaveId = offer.getSlaveId.getValue
val mem = getResource(offer.getResourcesList, "mem")
val cpus = getResource(offer.getResourcesList, "cpus").toInt
- if (totalCoresAcquired < maxCores &&
- mem >= MemoryUtils.calculateTotalMemory(sc) &&
+ val id = offer.getId.getValue
+ if (taskIdToSlaveId.size < executorLimit &&
+ totalCoresAcquired < maxCores &&
+ meetsConstraints &&
+ mem >= calculateTotalMemory(sc) &&
cpus >= 1 &&
failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES &&
!slaveIdsWithExecutors.contains(slaveId)) {
@@ -187,45 +224,44 @@ private[spark] class CoarseMesosSchedulerBackend(
taskIdToSlaveId(taskId) = slaveId
slaveIdsWithExecutors += slaveId
coresByTaskId(taskId) = cpusToUse
- val task = MesosTaskInfo.newBuilder()
+ // Gather cpu resources from the available resources and use them in the task.
+ val (remainingResources, cpuResourcesToUse) =
+ partitionResources(offer.getResourcesList, "cpus", cpusToUse)
+ val (_, memResourcesToUse) =
+ partitionResources(remainingResources, "mem", calculateTotalMemory(sc))
+ val taskBuilder = MesosTaskInfo.newBuilder()
.setTaskId(TaskID.newBuilder().setValue(taskId.toString).build())
.setSlaveId(offer.getSlaveId)
- .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave))
+ .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId))
.setName("Task " + taskId)
- .addResources(createResource("cpus", cpusToUse))
- .addResources(createResource("mem",
- MemoryUtils.calculateTotalMemory(sc)))
+ .addAllResources(cpuResourcesToUse)
+ .addAllResources(memResourcesToUse)
sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image =>
MesosSchedulerBackendUtil
- .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder())
+ .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder())
}
+ // accept the offer and launch the task
+ logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
d.launchTasks(
- Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters)
+ Collections.singleton(offer.getId),
+ Collections.singleton(taskBuilder.build()), filters)
} else {
- // Filter it out
- d.launchTasks(
- Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters)
+ // Decline the offer
+ logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
+ d.declineOffer(offer.getId)
}
}
}
}
- /** Build a Mesos resource protobuf object */
- private def createResource(resourceName: String, quantity: Double): Protos.Resource = {
- Resource.newBuilder()
- .setName(resourceName)
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(quantity).build())
- .build()
- }
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
val taskId = status.getTaskId.getValue.toInt
val state = status.getState
- logInfo("Mesos task " + taskId + " is now " + state)
- synchronized {
+ logInfo(s"Mesos task $taskId is now $state")
+ stateLock.synchronized {
if (TaskState.isFinished(TaskState.fromMesos(state))) {
val slaveId = taskIdToSlaveId(taskId)
slaveIdsWithExecutors -= slaveId
@@ -239,18 +275,19 @@ private[spark] class CoarseMesosSchedulerBackend(
if (TaskState.isFailed(TaskState.fromMesos(state))) {
failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1
if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) {
- logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " +
+ logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " +
"is Spark installed on it?")
}
}
+ executorTerminated(d, slaveId, s"Executor finished with state $state")
// In case we'd rejected everything before but have now lost a node
- mesosDriver.reviveOffers()
+ d.reviveOffers()
}
}
}
override def error(d: SchedulerDriver, message: String) {
- logError("Mesos error: " + message)
+ logError(s"Mesos error: $message")
scheduler.error(message)
}
@@ -263,18 +300,39 @@ private[spark] class CoarseMesosSchedulerBackend(
override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
- override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
- logInfo("Mesos slave lost: " + slaveId.getValue)
- synchronized {
- if (slaveIdsWithExecutors.contains(slaveId.getValue)) {
- // Note that the slave ID corresponds to the executor ID on that slave
- slaveIdsWithExecutors -= slaveId.getValue
- removeExecutor(slaveId.getValue, "Mesos slave lost")
+ /**
+ * Called when a slave is lost or a Mesos task finishes. Updates the local view of
+ * which tasks are running, removes the terminated slave from the set of pending
+ * slave IDs that we might have asked to be killed, and notifies the driver that
+ * an executor was removed.
+ */
+ private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = {
+ stateLock.synchronized {
+ if (slaveIdsWithExecutors.contains(slaveId)) {
+ val slaveIdToTaskId = taskIdToSlaveId.inverse()
+ if (slaveIdToTaskId.contains(slaveId)) {
+ val taskId: Int = slaveIdToTaskId.get(slaveId)
+ taskIdToSlaveId.remove(taskId)
+ removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason)
+ }
+ // TODO: This assumes one Spark executor per Mesos slave,
+ // which may no longer be true after SPARK-5095
+ pendingRemovedSlaveIds -= slaveId
+ slaveIdsWithExecutors -= slaveId
}
}
}
- override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) {
+ private def sparkExecutorId(slaveId: String, taskId: String): String = {
+ s"$slaveId/$taskId"
+ }
+
+ override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = {
+ logInfo(s"Mesos slave lost: ${slaveId.getValue}")
+ executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue)
+ }
+
+ override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = {
logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
slaveLost(d, s)
}
@@ -285,4 +343,34 @@ private[spark] class CoarseMesosSchedulerBackend(
super.applicationId
}
+ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
+ // We don't truly know if we can fulfill the full number of executors requested,
+ // since at coarse grain it depends on how many slaves are available.
+ logInfo("Capping the total amount of executors to " + requestedTotal)
+ executorLimitOption = Some(requestedTotal)
+ true
+ }
+
+ override def doKillExecutors(executorIds: Seq[String]): Boolean = {
+ if (mesosDriver == null) {
+ logWarning("Asked to kill executors before the Mesos driver was started.")
+ return false
+ }
+
+ val slaveIdToTaskId = taskIdToSlaveId.inverse()
+ for (executorId <- executorIds) {
+ val slaveId = executorId.split("/")(0)
+ if (slaveIdToTaskId.contains(slaveId)) {
+ mesosDriver.killTask(
+ TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build())
+ pendingRemovedSlaveIds += slaveId
+ } else {
+ logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler")
+ }
+ }
+ // no need to adjust `executorLimitOption` since the AllocationManager already communicated
+ // the desired limit through a call to `doRequestTotalExecutors`.
+ // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]]
+ true
+ }
}
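// The coarse-grained Mesos backend now derives executor ids as "<slaveId>/<taskId>"
// (see sparkExecutorId above); doKillExecutors() splits on '/' to recover the slave id.
// Illustrative value only:
val executorId = sparkExecutorId("20150730-120000-1234-S1", "3")  // => "20150730-120000-1234-S1/3"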
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index 1067a7f1caf4c..f078547e71352 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -29,6 +29,7 @@ import org.apache.mesos.Protos.Environment.Variable
import org.apache.mesos.Protos.TaskStatus.Reason
import org.apache.mesos.Protos.{TaskState => MesosTaskState, _}
import org.apache.mesos.{Scheduler, SchedulerDriver}
+
import org.apache.spark.deploy.mesos.MesosDriverDescription
import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse}
import org.apache.spark.metrics.MetricsSystem
@@ -294,20 +295,24 @@ private[spark] class MesosClusterScheduler(
def start(): Unit = {
// TODO: Implement leader election to make sure only one framework running in the cluster.
val fwId = schedulerState.fetch[String]("frameworkId")
- val builder = FrameworkInfo.newBuilder()
- .setUser(Utils.getCurrentUserName())
- .setName(appName)
- .setWebuiUrl(frameworkUrl)
- .setCheckpoint(true)
- .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash
fwId.foreach { id =>
- builder.setId(FrameworkID.newBuilder().setValue(id).build())
frameworkId = id
}
recoverState()
metricsSystem.registerSource(new MesosClusterSchedulerSource(this))
metricsSystem.start()
- startScheduler(master, MesosClusterScheduler.this, builder.build())
+ val driver = createSchedulerDriver(
+ master,
+ MesosClusterScheduler.this,
+ Utils.getCurrentUserName(),
+ appName,
+ conf,
+ Some(frameworkUrl),
+ Some(true),
+ Some(Integer.MAX_VALUE),
+ fwId)
+
+ startScheduler(driver)
ready = true
}
@@ -448,12 +453,8 @@ private[spark] class MesosClusterScheduler(
offer.cpu -= driverCpu
offer.mem -= driverMem
val taskId = TaskID.newBuilder().setValue(submission.submissionId).build()
- val cpuResource = Resource.newBuilder()
- .setName("cpus").setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build()
- val memResource = Resource.newBuilder()
- .setName("mem").setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build()
+ val cpuResource = createResource("cpus", driverCpu)
+ val memResource = createResource("mem", driverMem)
val commandInfo = buildDriverCommand(submission)
val appName = submission.schedulerProperties("spark.app.name")
val taskInfo = TaskInfo.newBuilder()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 49de85ef48ada..3f63ec1c5832f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -23,14 +23,15 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList}
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, HashSet}
+import org.apache.mesos.{Scheduler => MScheduler, _}
import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _}
import org.apache.mesos.protobuf.ByteString
-import org.apache.mesos.{Scheduler => MScheduler, _}
+import org.apache.spark.{SparkContext, SparkException, TaskState}
import org.apache.spark.executor.MesosExecutorBackend
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.util.Utils
-import org.apache.spark.{SparkContext, SparkException, TaskState}
+
/**
* A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a
@@ -45,8 +46,8 @@ private[spark] class MesosSchedulerBackend(
with MScheduler
with MesosSchedulerUtils {
- // Which slave IDs we have executors on
- val slaveIdsWithExecutors = new HashSet[String]
+ // Stores the slave ids that have launched a Mesos executor.
+ val slaveIdToExecutorInfo = new HashMap[String, MesosExecutorInfo]
val taskIdToSlaveId = new HashMap[Long, String]
// An ExecutorInfo for our tasks
@@ -59,20 +60,33 @@ private[spark] class MesosSchedulerBackend(
private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1)
+ // Offer constraints
+ private[this] val slaveOfferConstraints =
+ parseConstraintString(sc.conf.get("spark.mesos.constraints", ""))
+
@volatile var appId: String = _
override def start() {
- val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
classLoader = Thread.currentThread.getContextClassLoader
- startScheduler(master, MesosSchedulerBackend.this, fwInfo)
+ val driver = createSchedulerDriver(
+ master, MesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf)
+ startScheduler(driver)
}
- def createExecutorInfo(execId: String): MesosExecutorInfo = {
+ /**
+ * Creates a MesosExecutorInfo that is used to launch a Mesos executor.
+ * @param availableResources Available resources that are offered by Mesos
+ * @param execId The executor id to assign to this new executor.
+ * @return A tuple of the new mesos executor info and the remaining available resources.
+ */
+ def createExecutorInfo(
+ availableResources: JList[Resource],
+ execId: String): (MesosExecutorInfo, JList[Resource]) = {
val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home")
.orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility
.getOrElse {
- throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
- }
+ throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
+ }
val environment = Environment.newBuilder()
sc.conf.getOption("spark.executor.extraClassPath").foreach { cp =>
environment.addVariables(
@@ -111,32 +125,25 @@ private[spark] class MesosSchedulerBackend(
command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName")
command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get))
}
- val cpus = Resource.newBuilder()
- .setName("cpus")
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder()
- .setValue(mesosExecutorCores).build())
- .build()
- val memory = Resource.newBuilder()
- .setName("mem")
- .setType(Value.Type.SCALAR)
- .setScalar(
- Value.Scalar.newBuilder()
- .setValue(MemoryUtils.calculateTotalMemory(sc)).build())
- .build()
- val executorInfo = MesosExecutorInfo.newBuilder()
+ val builder = MesosExecutorInfo.newBuilder()
+ val (resourcesAfterCpu, usedCpuResources) =
+ partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK)
+ val (resourcesAfterMem, usedMemResources) =
+ partitionResources(resourcesAfterCpu, "mem", calculateTotalMemory(sc))
+
+ builder.addAllResources(usedCpuResources)
+ builder.addAllResources(usedMemResources)
+ val executorInfo = builder
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
- .addResources(cpus)
- .addResources(memory)
sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image =>
MesosSchedulerBackendUtil
.setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder())
}
- executorInfo.build()
+ (executorInfo.build(), resourcesAfterMem)
}
/**
@@ -179,6 +186,18 @@ private[spark] class MesosSchedulerBackend(
override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
+ private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = {
+ val builder = new StringBuilder
+ tasks.foreach { t =>
+ builder.append("Task id: ").append(t.getTaskId.getValue).append("\n")
+ .append("Slave id: ").append(t.getSlaveId.getValue).append("\n")
+ .append("Task resources: ").append(t.getResourcesList).append("\n")
+ .append("Executor resources: ").append(t.getExecutor.getResourcesList)
+ .append("---------------------------------------------\n")
+ }
+ builder.toString()
+ }
+
/**
* Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
* for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
@@ -191,15 +210,33 @@ private[spark] class MesosSchedulerBackend(
val mem = getResource(o.getResourcesList, "mem")
val cpus = getResource(o.getResourcesList, "cpus")
val slaveId = o.getSlaveId.getValue
- (mem >= MemoryUtils.calculateTotalMemory(sc) &&
- // need at least 1 for executor, 1 for task
- cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) ||
- (slaveIdsWithExecutors.contains(slaveId) &&
- cpus >= scheduler.CPUS_PER_TASK)
+ val offerAttributes = toAttributeMap(o.getAttributesList)
+
+ // check if all constraints are satisfied
+ // 1. Attribute constraints
+ // 2. Memory requirements
+ // 3. CPU requirements - need at least 1 for executor, 1 for task
+ val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes)
+ val meetsMemoryRequirements = mem >= calculateTotalMemory(sc)
+ val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)
+
+ val meetsRequirements =
+ (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) ||
+ (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK)
+
+ // add some debug messaging
+ val debugstr = if (meetsRequirements) "Accepting" else "Declining"
+ val id = o.getId.getValue
+ logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
+
+ meetsRequirements
}
+ // Decline offers we ruled out immediately
+ unUsableOffers.foreach(o => d.declineOffer(o.getId))
+
val workerOffers = usableOffers.map { o =>
- val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) {
+ val cpus = if (slaveIdToExecutorInfo.contains(o.getSlaveId.getValue)) {
getResource(o.getResourcesList, "cpus").toInt
} else {
// If the Mesos executor has not been started on this slave yet, set aside a few
@@ -214,6 +251,10 @@ private[spark] class MesosSchedulerBackend(
val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap
val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap
+ val slaveIdToResources = new HashMap[String, JList[Resource]]()
+ usableOffers.foreach { o =>
+ slaveIdToResources(o.getSlaveId.getValue) = o.getResourcesList
+ }
val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]]
@@ -225,11 +266,15 @@ private[spark] class MesosSchedulerBackend(
.foreach { offer =>
offer.foreach { taskDesc =>
val slaveId = taskDesc.executorId
- slaveIdsWithExecutors += slaveId
slavesIdsOfAcceptedOffers += slaveId
taskIdToSlaveId(taskDesc.taskId) = slaveId
+ val (mesosTask, remainingResources) = createMesosTask(
+ taskDesc,
+ slaveIdToResources(slaveId),
+ slaveId)
mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
- .add(createMesosTask(taskDesc, slaveId))
+ .add(mesosTask)
+ slaveIdToResources(slaveId) = remainingResources
}
}
@@ -242,6 +287,7 @@ private[spark] class MesosSchedulerBackend(
// TODO: Add support for log urls for Mesos
new ExecutorInfo(o.host, o.cores, Map.empty)))
)
+ logTrace(s"Launching Mesos tasks on slave '$slaveId', tasks:\n${getTasksSummary(tasks)}")
d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters)
}
@@ -250,28 +296,32 @@ private[spark] class MesosSchedulerBackend(
for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) {
d.declineOffer(o.getId)
}
-
- // Decline offers we ruled out immediately
- unUsableOffers.foreach(o => d.declineOffer(o.getId))
}
}
- /** Turn a Spark TaskDescription into a Mesos task */
- def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = {
+ /** Turn a Spark TaskDescription into a Mesos task and also resources unused by the task */
+ def createMesosTask(
+ task: TaskDescription,
+ resources: JList[Resource],
+ slaveId: String): (MesosTaskInfo, JList[Resource]) = {
val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build()
- val cpuResource = Resource.newBuilder()
- .setName("cpus")
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(scheduler.CPUS_PER_TASK).build())
- .build()
- MesosTaskInfo.newBuilder()
+ val (executorInfo, remainingResources) = if (slaveIdToExecutorInfo.contains(slaveId)) {
+ (slaveIdToExecutorInfo(slaveId), resources)
+ } else {
+ createExecutorInfo(resources, slaveId)
+ }
+ slaveIdToExecutorInfo(slaveId) = executorInfo
+ val (finalResources, cpuResources) =
+ partitionResources(remainingResources, "cpus", scheduler.CPUS_PER_TASK)
+ val taskInfo = MesosTaskInfo.newBuilder()
.setTaskId(taskId)
.setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
- .setExecutor(createExecutorInfo(slaveId))
+ .setExecutor(executorInfo)
.setName(task.name)
- .addResources(cpuResource)
+ .addAllResources(cpuResources)
.setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString)
.build()
+ (taskInfo, finalResources)
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
@@ -317,7 +367,7 @@ private[spark] class MesosSchedulerBackend(
private def removeExecutor(slaveId: String, reason: String) = {
synchronized {
listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason))
- slaveIdsWithExecutors -= slaveId
+ slaveIdToExecutorInfo -= slaveId
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index d11228f3d016a..c04920e4f5873 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -17,16 +17,21 @@
package org.apache.spark.scheduler.cluster.mesos
-import java.util.List
+import java.util.{List => JList}
import java.util.concurrent.CountDownLatch
import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
-import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status}
-import org.apache.mesos.{MesosSchedulerDriver, Scheduler}
-import org.apache.spark.Logging
+import com.google.common.base.Splitter
+import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos}
+import org.apache.mesos.Protos._
+import org.apache.mesos.protobuf.{ByteString, GeneratedMessage}
+import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext}
import org.apache.spark.util.Utils
+
/**
* Shared trait for implementing a Mesos Scheduler. This holds common state and helper
* methods and Mesos scheduler will use.
@@ -36,16 +41,66 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
private final val registerLatch = new CountDownLatch(1)
// Driver for talking to Mesos
- protected var mesosDriver: MesosSchedulerDriver = null
+ protected var mesosDriver: SchedulerDriver = null
/**
- * Starts the MesosSchedulerDriver with the provided information. This method returns
- * only after the scheduler has registered with Mesos.
- * @param masterUrl Mesos master connection URL
- * @param scheduler Scheduler object
- * @param fwInfo FrameworkInfo to pass to the Mesos master
+ * Creates a new MesosSchedulerDriver that communicates with the Mesos master.
+ * @param masterUrl The URL of the Mesos master to connect to
+ * @param scheduler the scheduler class to receive scheduler callbacks
+ * @param sparkUser User to impersonate when running tasks
+ * @param appName The framework name to display on the Mesos UI
+ * @param conf Spark configuration
+ * @param webuiUrl The WebUI URL to link from the Mesos UI
+ * @param checkpoint Whether to checkpoint tasks for failover
+ * @param failoverTimeout How long the Mesos master waits for the scheduler to reconnect after a disconnect
+ * @param frameworkId The id of the new framework
*/
- def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = {
+ protected def createSchedulerDriver(
+ masterUrl: String,
+ scheduler: Scheduler,
+ sparkUser: String,
+ appName: String,
+ conf: SparkConf,
+ webuiUrl: Option[String] = None,
+ checkpoint: Option[Boolean] = None,
+ failoverTimeout: Option[Double] = None,
+ frameworkId: Option[String] = None): SchedulerDriver = {
+ val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName)
+ val credBuilder = Credential.newBuilder()
+ webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) }
+ checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) }
+ failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) }
+ frameworkId.foreach { id =>
+ fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build())
+ }
+ conf.getOption("spark.mesos.principal").foreach { principal =>
+ fwInfoBuilder.setPrincipal(principal)
+ credBuilder.setPrincipal(principal)
+ }
+ conf.getOption("spark.mesos.secret").foreach { secret =>
+ credBuilder.setSecret(ByteString.copyFromUtf8(secret))
+ }
+ if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) {
+ throw new SparkException(
+ "spark.mesos.principal must be configured when spark.mesos.secret is set")
+ }
+ conf.getOption("spark.mesos.role").foreach { role =>
+ fwInfoBuilder.setRole(role)
+ }
+ if (credBuilder.hasPrincipal) {
+ new MesosSchedulerDriver(
+ scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build())
+ } else {
+ new MesosSchedulerDriver(scheduler, fwInfoBuilder.build(), masterUrl)
+ }
+ }
+
+ /**
+ * Starts the MesosSchedulerDriver and stores the new instance as the currently running driver.
+ * The given driver is expected not to be running yet.
+ * This method returns only after the scheduler has registered with Mesos.
+ */
+ def startScheduler(newDriver: SchedulerDriver): Unit = {
synchronized {
if (mesosDriver != null) {
registerLatch.await()
@@ -56,11 +111,11 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
setDaemon(true)
override def run() {
- mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl)
+ mesosDriver = newDriver
try {
val ret = mesosDriver.run()
logInfo("driver.run() returned with code " + ret)
- if (ret.equals(Status.DRIVER_ABORTED)) {
+ if (ret != null && ret.equals(Status.DRIVER_ABORTED)) {
System.exit(1)
}
} catch {
@@ -79,17 +134,201 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
/**
* Signal that the scheduler has registered with Mesos.
*/
+ protected def getResource(res: JList[Resource], name: String): Double = {
+    // A resource can appear multiple times in the offer since it can be offered
+    // either under a specific role or under the wildcard ('*') role.
+ res.filter(_.getName == name).map(_.getScalar.getValue).sum
+ }
+
protected def markRegistered(): Unit = {
registerLatch.countDown()
}
+ def createResource(name: String, amount: Double, role: Option[String] = None): Resource = {
+ val builder = Resource.newBuilder()
+ .setName(name)
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder().setValue(amount).build())
+
+ role.foreach { r => builder.setRole(r) }
+
+ builder.build()
+ }
+
+ /**
+   * Partition the existing set of resources into two groups: those remaining to be
+   * scheduled, and those requested to be used for a new task.
+ * @param resources The full list of available resources
+ * @param resourceName The name of the resource to take from the available resources
+ * @param amountToUse The amount of resources to take from the available resources
+ * @return The remaining resources list and the used resources list.
+ */
+ def partitionResources(
+ resources: JList[Resource],
+ resourceName: String,
+ amountToUse: Double): (List[Resource], List[Resource]) = {
+ var remain = amountToUse
+ var requestedResources = new ArrayBuffer[Resource]
+ val remainingResources = resources.map {
+ case r => {
+ if (remain > 0 &&
+ r.getType == Value.Type.SCALAR &&
+ r.getScalar.getValue > 0.0 &&
+ r.getName == resourceName) {
+ val usage = Math.min(remain, r.getScalar.getValue)
+ requestedResources += createResource(resourceName, usage, Some(r.getRole))
+ remain -= usage
+ createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole))
+ } else {
+ r
+ }
+ }
+ }
+
+    // Filter out any resource that has been fully depleted.
+ val filteredResources =
+ remainingResources.filter(r => r.getType != Value.Type.SCALAR || r.getScalar.getValue > 0.0)
+
+ (filteredResources.toList, requestedResources.toList)
+ }
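To illustrate the partitioning above (a sketch only; the offer is hand-built with the `createResource` helper defined earlier in this trait, and the role names are made up):

```scala
// Hypothetical offer: 4 CPUs reserved for role "prod" plus 4 unreserved ("*") CPUs.
val offer: java.util.List[Resource] = java.util.Arrays.asList(
  createResource("cpus", 4.0, Some("prod")),
  createResource("cpus", 4.0, Some("*")))

// Asking for 6 CPUs consumes all 4 "prod" CPUs and 2 of the "*" CPUs;
// `remaining` is left with a single 2.0-CPU resource under role "*".
val (remaining, used) = partitionResources(offer, "cpus", 6.0)
```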
+
+ /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */
+ protected def getAttribute(attr: Attribute): (String, Set[String]) = {
+ (attr.getName, attr.getText.getValue.split(',').toSet)
+ }
+
+ /** Build a Mesos resource protobuf object */
+ protected def createResource(resourceName: String, quantity: Double): Protos.Resource = {
+ Resource.newBuilder()
+ .setName(resourceName)
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder().setValue(quantity).build())
+ .build()
+ }
+
+ /**
+   * Converts the attributes from the resource offer into a Map of name -> Attribute Value.
+   * The values are the raw Mesos attribute protobuf types (scalar, ranges, set or text).
+   * @param offerAttributes the attributes of a resource offer
+   * @return a Map from attribute name to its protobuf value
+ */
+ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = {
+ offerAttributes.map(attr => {
+ val attrValue = attr.getType match {
+ case Value.Type.SCALAR => attr.getScalar
+ case Value.Type.RANGES => attr.getRanges
+ case Value.Type.SET => attr.getSet
+ case Value.Type.TEXT => attr.getText
+ }
+ (attr.getName, attrValue)
+ }).toMap
+ }
+
+ /**
+ * Match the requirements (if any) to the offer attributes.
+   * If no attribute requirements are specified, return true.
+   * Otherwise, if a required attribute is given without values, only check for its presence.
+   * Otherwise, if both name and values are given, perform a subset match on the slave attributes.
+ */
+ def matchesAttributeRequirements(
+ slaveOfferConstraints: Map[String, Set[String]],
+ offerAttributes: Map[String, GeneratedMessage]): Boolean = {
+ slaveOfferConstraints.forall {
+ // offer has the required attribute and subsumes the required values for that attribute
+ case (name, requiredValues) =>
+ offerAttributes.get(name) match {
+ case None => false
+ case Some(_) if requiredValues.isEmpty => true // empty value matches presence
+ case Some(scalarValue: Value.Scalar) =>
+              // check if any of the required values is less than or equal to the offered value
+ requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue)
+ case Some(rangeValue: Value.Range) =>
+ val offerRange = rangeValue.getBegin to rangeValue.getEnd
+            // Check if any required value falls within the offered range
+ // Note: We only support the ability to specify discrete values, in the future
+ // we may expand it to subsume ranges specified with a XX..YY value or something
+ // similar to that.
+ requiredValues.map(_.toLong).exists(offerRange.contains(_))
+ case Some(offeredValue: Value.Set) =>
+            // check if the specified required values are a subset of the offered set
+ requiredValues.subsetOf(offeredValue.getItemList.toSet)
+ case Some(textValue: Value.Text) =>
+            // check if the offered text equals one of the required values; if multiple
+            // values are specified, we succeed when any of them matches.
+ requiredValues.contains(textValue.getValue)
+ }
+ }
+ }
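A short, hand-built illustration of those rules. The attribute names and values are made up; `Value` and `GeneratedMessage` are the Mesos protobuf types already imported at the top of this file:

```scala
// Offer advertises a text attribute "zone" and a set attribute "disk".
val offerAttributes: Map[String, GeneratedMessage] = Map(
  "zone" -> Value.Text.newBuilder().setValue("us-east-1a").build(),
  "disk" -> Value.Set.newBuilder().addItem("ssd").addItem("hdd").build())

matchesAttributeRequirements(Map("zone" -> Set("us-east-1a", "us-east-1b")), offerAttributes) // true
matchesAttributeRequirements(Map("disk" -> Set[String]()), offerAttributes) // true: presence only
matchesAttributeRequirements(Map("zone" -> Set("eu-west-1a")), offerAttributes) // false
```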
+
/**
- * Get the amount of resources for the specified type from the resource list
+   * Parses the attribute constraints provided to Spark and builds a matching data structure:
+   *  Map[<attribute-name>, Set[values-to-match]]
+ * The constraints are specified as ';' separated key-value pairs where keys and values
+ * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for
+ * multiple values (comma separated). For example:
+ * {{{
+ * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b")
+ * // would result in
+ *
+ * Map(
+ * "tachyon" -> Set("true"),
+ * "zone": -> Set("us-east-1a", "us-east-1b")
+ * )
+ * }}}
+ *
+ * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/
+ * https://github.com/apache/mesos/blob/master/src/common/values.cpp
+ * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp
+ *
+   * @param constraintsVal constraints string consisting of ';'-separated key-value pairs, with
+   *                       key and value separated by ':'
+ * @return Map of constraints to match resources offers.
*/
- protected def getResource(res: List[Resource], name: String): Double = {
- for (r <- res if r.getName == name) {
- return r.getScalar.getValue
+ def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = {
+ /*
+ Based on mesos docs:
+ attributes : attribute ( ";" attribute )*
+ attribute : labelString ":" ( labelString | "," )+
+ labelString : [a-zA-Z0-9_/.-]
+ */
+    // Guava Splitter: ';' separates key-value pairs, ':' separates each key from its value
+    val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':')
+ if (constraintsVal.isEmpty) {
+ Map()
+ } else {
+ try {
+ Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map {
+ case (k, v) =>
+ if (v == null || v.isEmpty) {
+ (k, Set[String]())
+ } else {
+ (k, v.split(',').toSet)
+ }
+ }
+ } catch {
+ case NonFatal(e) =>
+ throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e)
+ }
}
- 0.0
}
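Putting the two helpers together, a sketch of the intended flow; the `offer` variable here stands for a hypothetical `Protos.Offer` received in `resourceOffers`:

```scala
// Parse the user-supplied constraint string once at start-up...
val constraints = parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b")

// ...then test every incoming offer against it.
val offerMatches = matchesAttributeRequirements(
  constraints, toAttributeMap(offer.getAttributesList))
```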
+
+  // These defaults are copied from YARN
+ private val MEMORY_OVERHEAD_FRACTION = 0.10
+ private val MEMORY_OVERHEAD_MINIMUM = 384
+
+ /**
+ * Return the amount of memory to allocate to each executor, taking into account
+ * container overheads.
+ * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value
+   * @return memory requirement: executor memory plus an overhead of
+   *         max(0.1 * executorMemory, MEMORY_OVERHEAD_MINIMUM), unless the overhead is
+   *         overridden via `spark.mesos.executor.memoryOverhead`
+ */
+ def calculateTotalMemory(sc: SparkContext): Int = {
+ sc.conf.getInt("spark.mesos.executor.memoryOverhead",
+ math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) +
+ sc.executorMemory
+ }
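As a concrete illustration: with `spark.executor.memory` of 4096 MB and no explicit overhead setting, the overhead is max(0.1 * 4096, 384) = 409 MB, so `calculateTotalMemory` returns 4096 + 409 = 4505 MB; setting `spark.mesos.executor.memoryOverhead` replaces the computed 409 MB.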
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 3078a1b10be8b..4d48fcfea44e7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -17,13 +17,16 @@
package org.apache.spark.scheduler.local
+import java.io.File
+import java.net.URL
import java.nio.ByteBuffer
import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
-import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.ExecutorInfo
private case class ReviveOffers()
@@ -40,6 +43,7 @@ private case class StopExecutor()
*/
private[spark] class LocalEndpoint(
override val rpcEnv: RpcEnv,
+ userClassPath: Seq[URL],
scheduler: TaskSchedulerImpl,
executorBackend: LocalBackend,
private val totalCores: Int)
@@ -47,11 +51,11 @@ private[spark] class LocalEndpoint(
private var freeCores = totalCores
- private val localExecutorId = SparkContext.DRIVER_IDENTIFIER
- private val localExecutorHostname = "localhost"
+ val localExecutorId = SparkContext.DRIVER_IDENTIFIER
+ val localExecutorHostname = "localhost"
private val executor = new Executor(
- localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true)
+ localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true)
override def receive: PartialFunction[Any, Unit] = {
case ReviveOffers =>
@@ -96,11 +100,28 @@ private[spark] class LocalBackend(
extends SchedulerBackend with ExecutorBackend with Logging {
private val appId = "local-" + System.currentTimeMillis
- var localEndpoint: RpcEndpointRef = null
+ private var localEndpoint: RpcEndpointRef = null
+ private val userClassPath = getUserClasspath(conf)
+ private val listenerBus = scheduler.sc.listenerBus
+
+ /**
+ * Returns a list of URLs representing the user classpath.
+ *
+ * @param conf Spark configuration.
+ */
+ def getUserClasspath(conf: SparkConf): Seq[URL] = {
+ val userClassPathStr = conf.getOption("spark.executor.extraClassPath")
+ userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL)
+ }
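For instance (paths are made up), a colon-separated `spark.executor.extraClassPath` is split on `File.pathSeparator` and turned into `file:` URLs that the local Executor adds to its class loader:

```scala
// Inside the backend: with this configuration, getUserClasspath returns
// Seq(file:/opt/libs/custom.jar, file:/opt/libs/other.jar) on a Unix-like host.
val conf = new SparkConf()
  .set("spark.executor.extraClassPath", "/opt/libs/custom.jar:/opt/libs/other.jar")
val urls = getUserClasspath(conf)
```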
override def start() {
- localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint(
- "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores))
+ val rpcEnv = SparkEnv.get.rpcEnv
+ val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores)
+ localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint)
+ listenerBus.post(SparkListenerExecutorAdded(
+ System.currentTimeMillis,
+ executorEndpoint.localExecutorId,
+ new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty)))
}
override def stop() {
diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
new file mode 100644
index 0000000000000..62f8aae7f2126
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+
+import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer}
+import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
+import org.apache.avro.{Schema, SchemaNormalization}
+import org.apache.avro.generic.{GenericData, GenericRecord}
+import org.apache.avro.io._
+import org.apache.commons.io.IOUtils
+
+import org.apache.spark.{SparkException, SparkEnv}
+import org.apache.spark.io.CompressionCodec
+
+/**
+ * Custom serializer used for generic Avro records. If the user registers the schemas
+ * ahead of time, then the schema's fingerprint will be sent with each message instead of the
+ * actual schema, reducing network IO.
+ * Actions like parsing or compressing schemas are computationally expensive, so the serializer
+ * caches all previously seen values to reduce the amount of work needed.
+ * @param schemas a map where the keys are unique IDs for Avro schemas and the values are the
+ * string representation of the Avro schema, used to decrease the amount of data
+ * that needs to be serialized.
+ */
+private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
+ extends KSerializer[GenericRecord] {
+
+ /** Used to reduce the amount of effort to compress the schema */
+ private val compressCache = new mutable.HashMap[Schema, Array[Byte]]()
+ private val decompressCache = new mutable.HashMap[ByteBuffer, Schema]()
+
+ /** Reuses the same datum reader/writer since the same schema will be used many times */
+ private val writerCache = new mutable.HashMap[Schema, DatumWriter[_]]()
+ private val readerCache = new mutable.HashMap[Schema, DatumReader[_]]()
+
+ /** Fingerprinting is very expensive so this alleviates most of the work */
+ private val fingerprintCache = new mutable.HashMap[Schema, Long]()
+ private val schemaCache = new mutable.HashMap[Long, Schema]()
+
+  // GenericAvroSerializer can't take a SparkConf in the constructor since it would then become
+  // a member of KryoSerializer, which would make KryoSerializer not Serializable. We make
+  // the codec lazy here because some unit tests use a KryoSerializer without having
+  // the SparkEnv set (note those tests would fail if they tried to serialize Avro data).
+ private lazy val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+
+ /**
+ * Used to compress Schemas when they are being sent over the wire.
+ * The compression results are memoized to reduce the compression time since the
+ * same schema is compressed many times over
+ */
+ def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, {
+ val bos = new ByteArrayOutputStream()
+ val out = codec.compressedOutputStream(bos)
+ out.write(schema.toString.getBytes("UTF-8"))
+ out.close()
+ bos.toByteArray
+ })
+
+ /**
+ * Decompresses the schema into the actual in-memory object. Keeps an internal cache of already
+   * seen values so as to limit the number of times that decompression has to be done.
+ */
+ def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, {
+ val bis = new ByteArrayInputStream(schemaBytes.array())
+ val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
+ new Schema.Parser().parse(new String(bytes, "UTF-8"))
+ })
+
+ /**
+   * Serializes a record to the given output stream. It caches a lot of the internal data,
+   * so as to avoid redoing work.
+ */
+ def serializeDatum[R <: GenericRecord](datum: R, output: KryoOutput): Unit = {
+ val encoder = EncoderFactory.get.binaryEncoder(output, null)
+ val schema = datum.getSchema
+ val fingerprint = fingerprintCache.getOrElseUpdate(schema, {
+ SchemaNormalization.parsingFingerprint64(schema)
+ })
+ schemas.get(fingerprint) match {
+ case Some(_) =>
+ output.writeBoolean(true)
+ output.writeLong(fingerprint)
+ case None =>
+ output.writeBoolean(false)
+ val compressedSchema = compress(schema)
+ output.writeInt(compressedSchema.length)
+ output.writeBytes(compressedSchema)
+ }
+
+ writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema))
+ .asInstanceOf[DatumWriter[R]]
+ .write(datum, encoder)
+ encoder.flush()
+ }
+
+ /**
+ * Deserializes generic records into their in-memory form. There is internal
+ * state to keep a cache of already seen schemas and datum readers.
+ */
+ def deserializeDatum(input: KryoInput): GenericRecord = {
+ val schema = {
+ if (input.readBoolean()) {
+ val fingerprint = input.readLong()
+ schemaCache.getOrElseUpdate(fingerprint, {
+ schemas.get(fingerprint) match {
+ case Some(s) => new Schema.Parser().parse(s)
+ case None =>
+ throw new SparkException(
+ "Error reading attempting to read avro data -- encountered an unknown " +
+ s"fingerprint: $fingerprint, not sure what schema to use. This could happen " +
+ "if you registered additional schemas after starting your spark context.")
+ }
+ })
+ } else {
+ val length = input.readInt()
+ decompress(ByteBuffer.wrap(input.readBytes(length)))
+ }
+ }
+ val decoder = DecoderFactory.get.directBinaryDecoder(input, null)
+ readerCache.getOrElseUpdate(schema, GenericData.get.createDatumReader(schema))
+ .asInstanceOf[DatumReader[GenericRecord]]
+ .read(null, decoder)
+ }
+
+ override def write(kryo: Kryo, output: KryoOutput, datum: GenericRecord): Unit =
+ serializeDatum(datum, output)
+
+ override def read(kryo: Kryo, input: KryoInput, datumClass: Class[GenericRecord]): GenericRecord =
+ deserializeDatum(input)
+}
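A sketch of how this serializer is intended to be used from the application side. The `registerAvroSchemas` call is assumed to be the `SparkConf` counterpart of the `conf.getAvroSchema` lookup added to `KryoSerializer` below, and the schema JSON is purely illustrative:

```scala
import org.apache.avro.Schema
import org.apache.spark.SparkConf

// Illustrative record schema; in practice this often comes from an .avsc file.
val userSchema: Schema = new Schema.Parser().parse(
  """{"type": "record", "name": "User", "fields": [
    |  {"name": "name", "type": "string"},
    |  {"name": "age", "type": "int"}]}""".stripMargin)

// Pre-registering the schema means only its 64-bit fingerprint is shipped with
// each record instead of the full (compressed) schema.
val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .registerAvroSchemas(userSchema)
```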
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 698d1384d580d..4a5274b46b7a0 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -62,8 +62,11 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa
extends DeserializationStream {
private val objIn = new ObjectInputStream(in) {
- override def resolveClass(desc: ObjectStreamClass): Class[_] =
+ override def resolveClass(desc: ObjectStreamClass): Class[_] = {
+ // scalastyle:off classforname
Class.forName(desc.getName, false, loader)
+ // scalastyle:on classforname
+ }
}
def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index ed35cffe968f8..0ff7562e912ca 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -27,6 +27,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
+import org.apache.avro.generic.{GenericData, GenericRecord}
import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap}
import org.apache.spark._
@@ -73,6 +74,8 @@ class KryoSerializer(conf: SparkConf)
.split(',')
.filter(!_.isEmpty)
+ private val avroSchemas = conf.getAvroSchema
+
def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
def newKryo(): Kryo = {
@@ -101,7 +104,11 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
+ kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
+ kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
+
try {
+ // scalastyle:off classforname
// Use the default classloader when calling the user registrator.
Thread.currentThread.setContextClassLoader(classLoader)
// Register classes given through spark.kryo.classesToRegister.
@@ -111,6 +118,7 @@ class KryoSerializer(conf: SparkConf)
userRegistrator
.map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
.foreach { reg => reg.registerClasses(kryo) }
+ // scalastyle:on classforname
} catch {
case e: Exception =>
throw new SparkException(s"Failed to register classes with Kryo", e)
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index cc2f0506817d3..a1b1e1631eafb 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -407,7 +407,9 @@ private[spark] object SerializationDebugger extends Logging {
/** ObjectStreamClass$ClassDataSlot.desc field */
val DescField: Field = {
+ // scalastyle:off classforname
val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc")
+ // scalastyle:on classforname
f.setAccessible(true)
f
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6c3b3080d2605..f6a96d81e7aa9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
- val writers: Array[BlockObjectWriter]
+ val writers: Array[DiskBlockObjectWriter]
/** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
def releaseWriters(success: Boolean)
@@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
val openStartTime = System.nanoTime
val serializerInstance = serializer.newInstance()
- val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+ val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) {
fileGroup = getUnusedFileGroup()
- Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize,
writeMetrics)
}
} else {
- Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
val blockFile = blockManager.diskBlockManager.getFile(blockId)
// Because of previous failures, the shuffle file may already exist on this machine.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d9c63b6e7bbb9..fae69551e7330 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
}
private[spark] object IndexShuffleBlockResolver {
- // No-op reduce ID used in interactions with disk store and BlockObjectWriter.
+ // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter.
// The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort
// shuffle outputs for several reduces are glommed into a single file.
// TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index 3bcc7178a3d8b..f038b722957b8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -19,95 +19,101 @@ package org.apache.spark.shuffle
import scala.collection.mutable
-import org.apache.spark.{Logging, SparkException, SparkConf}
+import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext}
/**
- * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
+ * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling
* collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
* from this pool and release it as it spills data out. When a task ends, all its memory will be
* released by the Executor.
*
- * This class tries to ensure that each thread gets a reasonable share of memory, instead of some
- * thread ramping up to a large amount first and then causing others to spill to disk repeatedly.
- * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory
+ * This class tries to ensure that each task gets a reasonable share of memory, instead of some
+ * task ramping up to a large amount first and then causing others to spill to disk repeatedly.
+ * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
* before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
- * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever
+ * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
* this set changes. This is all done by synchronizing access on "this" to mutate state and using
* wait() and notifyAll() to signal changes.
*/
private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
- private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes
+ private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes
def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
+ private def currentTaskAttemptId(): Long = {
+ // In case this is called on the driver, return an invalid task attempt id.
+ Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+ }
+
/**
- * Try to acquire up to numBytes memory for the current thread, and return the number of bytes
+ * Try to acquire up to numBytes memory for the current task, and return the number of bytes
* obtained, or 0 if none can be allocated. This call may block until there is enough free memory
- * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the
- * total memory pool (where N is the # of active threads) before it is forced to spill. This can
- * happen if the number of threads increases but an older thread had a lot of memory already.
+ * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the
+ * total memory pool (where N is the # of active tasks) before it is forced to spill. This can
+ * happen if the number of tasks increases but an older task had a lot of memory already.
*/
def tryToAcquire(numBytes: Long): Long = synchronized {
- val threadId = Thread.currentThread().getId
+ val taskAttemptId = currentTaskAttemptId()
assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
- // Add this thread to the threadMemory map just so we can keep an accurate count of the number
- // of active threads, to let other threads ramp down their memory in calls to tryToAcquire
- if (!threadMemory.contains(threadId)) {
- threadMemory(threadId) = 0L
- notifyAll() // Will later cause waiting threads to wake up and check numThreads again
+ // Add this task to the taskMemory map just so we can keep an accurate count of the number
+ // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
+ if (!taskMemory.contains(taskAttemptId)) {
+ taskMemory(taskAttemptId) = 0L
+      notifyAll() // Will later cause waiting tasks to wake up and check numActiveTasks again
}
// Keep looping until we're either sure that we don't want to grant this request (because this
- // thread would have more than 1 / numActiveThreads of the memory) or we have enough free
- // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)).
+ // task would have more than 1 / numActiveTasks of the memory) or we have enough free
+ // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
while (true) {
- val numActiveThreads = threadMemory.keys.size
- val curMem = threadMemory(threadId)
- val freeMemory = maxMemory - threadMemory.values.sum
+ val numActiveTasks = taskMemory.keys.size
+ val curMem = taskMemory(taskAttemptId)
+ val freeMemory = maxMemory - taskMemory.values.sum
- // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads;
+ // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
// don't let it be negative
- val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem))
+ val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem))
- if (curMem < maxMemory / (2 * numActiveThreads)) {
- // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
- // if we can't give it this much now, wait for other threads to free up memory
- // (this happens if older threads allocated lots of memory before N grew)
- if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
+ if (curMem < maxMemory / (2 * numActiveTasks)) {
+ // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
+ // if we can't give it this much now, wait for other tasks to free up memory
+ // (this happens if older tasks allocated lots of memory before N grew)
+ if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) {
val toGrant = math.min(maxToGrant, freeMemory)
- threadMemory(threadId) += toGrant
+ taskMemory(taskAttemptId) += toGrant
return toGrant
} else {
- logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
+ logInfo(
+ s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
}
} else {
// Only give it as much memory as is free, which might be none if it reached 1 / numThreads
val toGrant = math.min(maxToGrant, freeMemory)
- threadMemory(threadId) += toGrant
+ taskMemory(taskAttemptId) += toGrant
return toGrant
}
}
0L // Never reached
}
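For example, with maxMemory = 1000 MB and 4 active tasks, `tryToAcquire` guarantees each task can ramp up to at least 1000 / (2 * 4) = 125 MB before it is forced to spill, and never lets a single task hold more than 1000 / 4 = 250 MB; if a fifth task registers, waiting tasks recompute both bounds (to 100 MB and 200 MB) the next time they wake up.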
- /** Release numBytes bytes for the current thread. */
+ /** Release numBytes bytes for the current task. */
def release(numBytes: Long): Unit = synchronized {
- val threadId = Thread.currentThread().getId
- val curMem = threadMemory.getOrElse(threadId, 0L)
+ val taskAttemptId = currentTaskAttemptId()
+ val curMem = taskMemory.getOrElse(taskAttemptId, 0L)
if (curMem < numBytes) {
throw new SparkException(
- s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
+ s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}")
}
- threadMemory(threadId) -= numBytes
+ taskMemory(taskAttemptId) -= numBytes
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
- /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */
- def releaseMemoryForThisThread(): Unit = synchronized {
- val threadId = Thread.currentThread().getId
- threadMemory.remove(threadId)
+ /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */
+ def releaseMemoryForThisTask(): Unit = synchronized {
+ val taskAttemptId = currentTaskAttemptId()
+ taskMemory.remove(taskAttemptId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
deleted file mode 100644
index 9d8e7e9f03aea..0000000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.hash
-
-import java.io.InputStream
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.util.{Failure, Success}
-
-import org.apache.spark._
-import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
- ShuffleBlockId}
-
-private[hash] object BlockStoreShuffleFetcher extends Logging {
- def fetchBlockStreams(
- shuffleId: Int,
- reduceId: Int,
- context: TaskContext,
- blockManager: BlockManager,
- mapOutputTracker: MapOutputTracker)
- : Iterator[(BlockId, InputStream)] =
- {
- logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-
- val startTime = System.currentTimeMillis
- val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
- logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
- shuffleId, reduceId, System.currentTimeMillis - startTime))
-
- val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
- for (((address, size), index) <- statuses.zipWithIndex) {
- splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
- }
-
- val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
- case (address, splits) =>
- (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
- }
-
- val blockFetcherItr = new ShuffleBlockFetcherIterator(
- context,
- blockManager.shuffleClient,
- blockManager,
- blocksByAddress,
- // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
-
- // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
- blockFetcherItr.map { blockPair =>
- val blockId = blockPair._1
- val blockOption = blockPair._2
- blockOption match {
- case Success(inputStream) => {
- (blockId, inputStream)
- }
- case Failure(e) => {
- blockId match {
- case ShuffleBlockId(shufId, mapId, _) =>
- val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
- case _ =>
- throw new SparkException(
- "Failed to get block " + blockId + ", which is not a shuffle block", e)
- }
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index d5c9880659dd3..de79fa56f017b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,10 +17,10 @@
package org.apache.spark.shuffle.hash
-import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -31,8 +31,8 @@ private[spark] class HashShuffleReader[K, C](
context: TaskContext,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
- extends ShuffleReader[K, C]
-{
+ extends ShuffleReader[K, C] with Logging {
+
require(endPartition == startPartition + 1,
"Hash shuffle currently only supports fetching one partition")
@@ -40,11 +40,16 @@ private[spark] class HashShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
- handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
+ val blockFetcherItr = new ShuffleBlockFetcherIterator(
+ context,
+ blockManager.shuffleClient,
+ blockManager,
+ mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
// Wrap the streams for compression based on configuration
- val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+ val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index eb87cee15903c..41df70c602c30 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
private[spark] class HashShuffleWriter[K, V](
shuffleBlockResolver: FileShuffleBlockResolver,
@@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V](
private def commitWritesAndBuildStatus(): MapStatus = {
// Commit the writes. Get the size of each bucket block (total block size).
- val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter =>
+ val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter =>
writer.commitAndClose()
writer.fileSegment().length
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1beafa1771448..86493673d958d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -648,7 +648,7 @@ private[spark] class BlockManager(
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
- writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
+ writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 68ed9096731c5..5dc0c537cbb62 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -60,10 +60,11 @@ class BlockManagerMasterEndpoint(
register(blockManagerId, maxMemSize, slaveEndpoint)
context.reply(true)
- case UpdateBlockInfo(
+ case _updateBlockInfo @ UpdateBlockInfo(
blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) =>
context.reply(updateBlockInfo(
blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize))
+ listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo)))
case GetLocations(blockId) =>
context.reply(getLocations(blockId))
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala
new file mode 100644
index 0000000000000..2789e25b8d3ab
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable
+
+import org.apache.spark.scheduler._
+
+private[spark] case class BlockUIData(
+ blockId: BlockId,
+ location: String,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long,
+ externalBlockStoreSize: Long)
+
+/**
+ * The aggregated status of stream blocks in an executor
+ */
+private[spark] case class ExecutorStreamBlockStatus(
+ executorId: String,
+ location: String,
+ blocks: Seq[BlockUIData]) {
+
+ def totalMemSize: Long = blocks.map(_.memSize).sum
+
+ def totalDiskSize: Long = blocks.map(_.diskSize).sum
+
+ def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum
+
+ def numStreamBlocks: Int = blocks.size
+
+}
+
+private[spark] class BlockStatusListener extends SparkListener {
+
+ private val blockManagers =
+ new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]]
+
+ override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+ val blockId = blockUpdated.blockUpdatedInfo.blockId
+ if (!blockId.isInstanceOf[StreamBlockId]) {
+      // For now we only monitor StreamBlocks
+ return
+ }
+ val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId
+ val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel
+ val memSize = blockUpdated.blockUpdatedInfo.memSize
+ val diskSize = blockUpdated.blockUpdatedInfo.diskSize
+ val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize
+
+ synchronized {
+ // Drop the update info if the block manager is not registered
+ blockManagers.get(blockManagerId).foreach { blocksInBlockManager =>
+ if (storageLevel.isValid) {
+ blocksInBlockManager.put(blockId,
+ BlockUIData(
+ blockId,
+ blockManagerId.hostPort,
+ storageLevel,
+ memSize,
+ diskSize,
+ externalBlockStoreSize)
+ )
+ } else {
+ // If isValid is not true, it means we should drop the block.
+ blocksInBlockManager -= blockId
+ }
+ }
+ }
+ }
+
+ override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = {
+ synchronized {
+ blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap())
+ }
+ }
+
+ override def onBlockManagerRemoved(
+ blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized {
+ blockManagers -= blockManagerRemoved.blockManagerId
+ }
+
+ def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized {
+ blockManagers.map { case (blockManagerId, blocks) =>
+ ExecutorStreamBlockStatus(
+ blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq)
+ }.toSeq
+ }
+}
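A sketch of how the listener might be consumed (illustrative only; the class is `private[spark]`, so in practice it is attached and queried from inside Spark, for example by a UI tab):

```scala
val listener = new BlockStatusListener
sc.addSparkListener(listener)

// Later, e.g. when rendering a status page, summarize stream blocks per executor.
listener.allExecutorStreamBlockStatus.foreach { status =>
  println(s"${status.executorId} @ ${status.location}: " +
    s"${status.numStreamBlocks} stream blocks, ${status.totalMemSize} bytes in memory")
}
```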
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala
new file mode 100644
index 0000000000000..a5790e4454a89
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.storage.BlockManagerMessages.UpdateBlockInfo
+
+/**
+ * :: DeveloperApi ::
+ * Stores information about a block status in a block manager.
+ */
+@DeveloperApi
+case class BlockUpdatedInfo(
+ blockManagerId: BlockManagerId,
+ blockId: BlockId,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long,
+ externalBlockStoreSize: Long)
+
+private[spark] object BlockUpdatedInfo {
+
+ private[spark] def apply(updateBlockInfo: UpdateBlockInfo): BlockUpdatedInfo = {
+ BlockUpdatedInfo(
+ updateBlockInfo.blockManagerId,
+ updateBlockInfo.blockId,
+ updateBlockInfo.storageLevel,
+ updateBlockInfo.memSize,
+ updateBlockInfo.diskSize,
+ updateBlockInfo.externalBlockStoreSize)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 91ef86389a0c3..5f537692a16c5 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -124,10 +124,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
(blockId, getFile(blockId))
}
+ /**
+ * Create local directories for storing block data. These directories are
+ * located inside configured local directories and won't
+ * be deleted on JVM exit when using the external shuffle service.
+ */
private def createLocalDirs(conf: SparkConf): Array[File] = {
- Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir =>
+ Utils.getConfiguredLocalDirs(conf).flatMap { rootDir =>
try {
val localDir = Utils.createDirectory(rootDir, "blockmgr")
+ Utils.chmod700(localDir)
logInfo(s"Created local directory at $localDir")
Some(localDir)
} catch {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
similarity index 83%
rename from core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
rename to core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 7eeabd1e0489c..49d9154f95a5b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -26,66 +26,25 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.util.Utils
/**
- * An interface for writing JVM objects to some underlying storage. This interface allows
- * appending data to an existing block, and can guarantee atomicity in the case of faults
- * as it allows the caller to revert partial writes.
+ * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
+ * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to
+ * revert partial writes.
*
- * This interface does not support concurrent writes. Also, once the writer has
- * been opened, it cannot be reopened again.
- */
-private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {
-
- def open(): BlockObjectWriter
-
- def close()
-
- def isOpen: Boolean
-
- /**
- * Flush the partial writes and commit them as a single atomic block.
- */
- def commitAndClose(): Unit
-
- /**
- * Reverts writes that haven't been flushed yet. Callers should invoke this function
- * when there are runtime exceptions. This method will not throw, though it may be
- * unsuccessful in truncating written data.
- */
- def revertPartialWritesAndClose()
-
- /**
- * Writes a key-value pair.
- */
- def write(key: Any, value: Any)
-
- /**
- * Notify the writer that a record worth of bytes has been written with OutputStream#write.
- */
- def recordWritten()
-
- /**
- * Returns the file segment of committed data that this Writer has written.
- * This is only valid after commitAndClose() has been called.
- */
- def fileSegment(): FileSegment
-}
-
-/**
- * BlockObjectWriter which writes directly to a file on disk. Appends to the given file.
+ * This class does not support concurrent writes. Also, once the writer has been opened it cannot be
+ * reopened again.
*/
private[spark] class DiskBlockObjectWriter(
- blockId: BlockId,
+ val blockId: BlockId,
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
- // These write metrics concurrently shared with other active BlockObjectWriter's who
+    // These write metrics are concurrently shared with other active DiskBlockObjectWriters that
// are themselves performing writes. All updates must be relative.
writeMetrics: ShuffleWriteMetrics)
- extends BlockObjectWriter(blockId)
- with Logging
-{
+ extends OutputStream
+ with Logging {
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
@@ -122,7 +81,7 @@ private[spark] class DiskBlockObjectWriter(
*/
private var numRecordsWritten = 0
- override def open(): BlockObjectWriter = {
+ def open(): DiskBlockObjectWriter = {
if (hasBeenClosed) {
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
@@ -159,9 +118,12 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def isOpen: Boolean = objOut != null
+ def isOpen: Boolean = objOut != null
- override def commitAndClose(): Unit = {
+ /**
+ * Flush the partial writes and commit them as a single atomic block.
+ */
+ def commitAndClose(): Unit = {
if (initialized) {
// NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
// serializer stream and the lower level stream.
@@ -177,9 +139,15 @@ private[spark] class DiskBlockObjectWriter(
commitAndCloseHasBeenCalled = true
}
- // Discard current writes. We do this by flushing the outstanding writes and then
- // truncating the file to its initial position.
- override def revertPartialWritesAndClose() {
+
+ /**
+ * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * when there are runtime exceptions. This method will not throw, though it may be
+ * unsuccessful in truncating written data.
+ */
+ def revertPartialWritesAndClose() {
+ // Discard current writes. We do this by flushing the outstanding writes and then
+ // truncating the file to its initial position.
try {
if (initialized) {
writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
@@ -201,7 +169,10 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def write(key: Any, value: Any) {
+ /**
+ * Writes a key-value pair.
+ */
+ def write(key: Any, value: Any) {
if (!initialized) {
open()
}
@@ -221,7 +192,10 @@ private[spark] class DiskBlockObjectWriter(
bs.write(kvBytes, offs, len)
}
- override def recordWritten(): Unit = {
+ /**
+ * Notify the writer that a record worth of bytes has been written with OutputStream#write.
+ */
+ def recordWritten(): Unit = {
numRecordsWritten += 1
writeMetrics.incShuffleRecordsWritten(1)
@@ -230,7 +204,11 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def fileSegment(): FileSegment = {
+ /**
+ * Returns the file segment of committed data that this Writer has written.
+ * This is only valid after commitAndClose() has been called.
+ */
+ def fileSegment(): FileSegment = {
if (!commitAndCloseHasBeenCalled) {
throw new IllegalStateException(
"fileSegment() is only valid after commitAndClose() has been called")
diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
index 291394ed34816..db965d54bafd6 100644
--- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
@@ -192,7 +192,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId:
.getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME)
try {
- val instance = Class.forName(clsName)
+ val instance = Utils.classForName(clsName)
.newInstance()
.asInstanceOf[ExternalBlockManager]
instance.init(blockManager, executorId)
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index ed609772e6979..6f27f00307f8c 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -23,6 +23,7 @@ import java.util.LinkedHashMap
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.TaskContext
import org.apache.spark.util.{SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
@@ -43,11 +44,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// Ensure only one thread is putting, and if necessary, dropping blocks at any given time
private val accountingLock = new Object
- // A mapping from thread ID to amount of memory used for unrolling a block (in bytes)
+ // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes)
// All accesses of this map are assumed to have manually synchronized on `accountingLock`
private val unrollMemoryMap = mutable.HashMap[Long, Long]()
// Same as `unrollMemoryMap`, but for pending unroll memory as defined below.
- // Pending unroll memory refers to the intermediate memory occupied by a thread
+ // Pending unroll memory refers to the intermediate memory occupied by a task
// after the unroll but before the actual putting of the block in the cache.
// This chunk of memory is expected to be released *as soon as* we finish
// caching the corresponding block as opposed to until after the task finishes.
@@ -250,21 +251,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
var elementsUnrolled = 0
// Whether there is still enough memory for us to continue unrolling this block
var keepUnrolling = true
- // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing.
+ // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing.
val initialMemoryThreshold = unrollMemoryThreshold
// How often to check whether we need to request more memory
val memoryCheckPeriod = 16
- // Memory currently reserved by this thread for this particular unrolling operation
+ // Memory currently reserved by this task for this particular unrolling operation
var memoryThreshold = initialMemoryThreshold
// Memory to request as a multiple of current vector size
val memoryGrowthFactor = 1.5
- // Previous unroll memory held by this thread, for releasing later (only at the very end)
- val previousMemoryReserved = currentUnrollMemoryForThisThread
+ // Previous unroll memory held by this task, for releasing later (only at the very end)
+ val previousMemoryReserved = currentUnrollMemoryForThisTask
// Underlying vector for unrolling the block
var vector = new SizeTrackingVector[Any]
// Request enough memory to begin unrolling
- keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold)
+ keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold)
if (!keepUnrolling) {
logWarning(s"Failed to reserve initial memory threshold of " +
@@ -283,7 +284,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// Hold the accounting lock, in case another thread concurrently puts a block that
// takes up the unrolling space we just ensured here
accountingLock.synchronized {
- if (!reserveUnrollMemoryForThisThread(amountToRequest)) {
+ if (!reserveUnrollMemoryForThisTask(amountToRequest)) {
// If the first request is not granted, try again after ensuring free space
// If there is still not enough space, give up and drop the partition
val spaceToEnsure = maxUnrollMemory - currentUnrollMemory
@@ -291,7 +292,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
val result = ensureFreeSpace(blockId, spaceToEnsure)
droppedBlocks ++= result.droppedBlocks
}
- keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest)
+ keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest)
}
}
// New threshold is currentSize * memoryGrowthFactor
@@ -317,9 +318,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// later when the task finishes.
if (keepUnrolling) {
accountingLock.synchronized {
- val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved
- releaseUnrollMemoryForThisThread(amountToRelease)
- reservePendingUnrollMemoryForThisThread(amountToRelease)
+ val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved
+ releaseUnrollMemoryForThisTask(amountToRelease)
+ reservePendingUnrollMemoryForThisTask(amountToRelease)
}
}
}
@@ -397,7 +398,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
}
// Release the unroll memory used because we no longer need the underlying Array
- releasePendingUnrollMemoryForThisThread()
+ releasePendingUnrollMemoryForThisTask()
}
ResultWithDroppedBlocks(putSuccess, droppedBlocks)
}
@@ -427,9 +428,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// Take into account the amount of memory currently occupied by unrolling blocks
// and subtract the pending unroll memory for that block on the current task.
- val threadId = Thread.currentThread().getId
+ val taskAttemptId = currentTaskAttemptId()
val actualFreeMemory = freeMemory - currentUnrollMemory +
- pendingUnrollMemoryMap.getOrElse(threadId, 0L)
+ pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L)
if (actualFreeMemory < space) {
val rddToAdd = getRddId(blockIdToAdd)
@@ -455,7 +456,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
logInfo(s"${selectedBlocks.size} blocks selected for dropping")
for (blockId <- selectedBlocks) {
val entry = entries.synchronized { entries.get(blockId) }
- // This should never be null as only one thread should be dropping
+ // This should never be null as only one task should be dropping
// blocks and removing entries. However the check is still here for
// future safety.
if (entry != null) {
@@ -482,79 +483,85 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
entries.synchronized { entries.containsKey(blockId) }
}
+ private def currentTaskAttemptId(): Long = {
+ // In case this is called on the driver, return an invalid task attempt id.
+ Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+ }
+
/**
- * Reserve additional memory for unrolling blocks used by this thread.
+ * Reserve additional memory for unrolling blocks used by this task.
* Return whether the request is granted.
*/
- def reserveUnrollMemoryForThisThread(memory: Long): Boolean = {
+ def reserveUnrollMemoryForThisTask(memory: Long): Boolean = {
accountingLock.synchronized {
val granted = freeMemory > currentUnrollMemory + memory
if (granted) {
- val threadId = Thread.currentThread().getId
- unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory
+ val taskAttemptId = currentTaskAttemptId()
+ unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
}
granted
}
}
/**
- * Release memory used by this thread for unrolling blocks.
- * If the amount is not specified, remove the current thread's allocation altogether.
+ * Release memory used by this task for unrolling blocks.
+ * If the amount is not specified, remove the current task's allocation altogether.
*/
- def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = {
- val threadId = Thread.currentThread().getId
+ def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = {
+ val taskAttemptId = currentTaskAttemptId()
accountingLock.synchronized {
if (memory < 0) {
- unrollMemoryMap.remove(threadId)
+ unrollMemoryMap.remove(taskAttemptId)
} else {
- unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory
- // If this thread claims no more unroll memory, release it completely
- if (unrollMemoryMap(threadId) <= 0) {
- unrollMemoryMap.remove(threadId)
+ unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory
+ // If this task claims no more unroll memory, release it completely
+ if (unrollMemoryMap(taskAttemptId) <= 0) {
+ unrollMemoryMap.remove(taskAttemptId)
}
}
}
}
/**
- * Reserve the unroll memory of current unroll successful block used by this thread
+ * Reserve the unroll memory of current unroll successful block used by this task
* until actually put the block into memory entry.
*/
- def reservePendingUnrollMemoryForThisThread(memory: Long): Unit = {
- val threadId = Thread.currentThread().getId
+ def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = {
+ val taskAttemptId = currentTaskAttemptId()
accountingLock.synchronized {
- pendingUnrollMemoryMap(threadId) = pendingUnrollMemoryMap.getOrElse(threadId, 0L) + memory
+ pendingUnrollMemoryMap(taskAttemptId) =
+ pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
}
}
/**
- * Release pending unroll memory of current unroll successful block used by this thread
+ * Release pending unroll memory of current unroll successful block used by this task
*/
- def releasePendingUnrollMemoryForThisThread(): Unit = {
- val threadId = Thread.currentThread().getId
+ def releasePendingUnrollMemoryForThisTask(): Unit = {
+ val taskAttemptId = currentTaskAttemptId()
accountingLock.synchronized {
- pendingUnrollMemoryMap.remove(threadId)
+ pendingUnrollMemoryMap.remove(taskAttemptId)
}
}
/**
- * Return the amount of memory currently occupied for unrolling blocks across all threads.
+ * Return the amount of memory currently occupied for unrolling blocks across all tasks.
*/
def currentUnrollMemory: Long = accountingLock.synchronized {
unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum
}
/**
- * Return the amount of memory currently occupied for unrolling blocks by this thread.
+ * Return the amount of memory currently occupied for unrolling blocks by this task.
*/
- def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized {
- unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L)
+ def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized {
+ unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L)
}
/**
- * Return the number of threads currently unrolling blocks.
+ * Return the number of tasks currently unrolling blocks.
*/
- def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size }
+ def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size }
/**
* Log information about current memory usage.
@@ -566,7 +573,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
logInfo(
s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " +
s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " +
- s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " +
+      s"$numTasksUnrolling task(s)) = ${Utils.bytesToString(totalMemory)}. " +
s"Storage limit = ${Utils.bytesToString(maxMemory)}."
)
}
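
The MemoryStore hunks above switch unroll-memory bookkeeping from thread IDs to task attempt IDs, so a task keeps a single accounting entry regardless of which thread runs it, and driver-side callers fall back to a sentinel ID of -1. As a rough illustration only (not Spark's actual class), here is a minimal sketch of that accounting pattern; `UnrollMemoryBookkeeper` and the injected `currentTaskAttemptId` function are made-up names standing in for the real lookup via `TaskContext.get()`:

```scala
import scala.collection.mutable

// Simplified sketch: per-task unroll-memory reservations keyed by task attempt ID.
// `currentTaskAttemptId` stands in for Option(TaskContext.get()).map(_.taskAttemptId()),
// returning -1L when there is no task context (i.e. on the driver).
class UnrollMemoryBookkeeper(maxMemory: Long, currentTaskAttemptId: () => Long) {
  private val accountingLock = new Object
  // taskAttemptId -> bytes reserved for unrolling; previously this was keyed by thread ID
  private val unrollMemoryMap = mutable.HashMap[Long, Long]()

  private def currentUnrollMemory: Long = accountingLock.synchronized {
    unrollMemoryMap.values.sum
  }

  /** Try to reserve `memory` bytes for the current task; return whether it was granted. */
  def reserveForThisTask(memory: Long): Boolean = accountingLock.synchronized {
    val granted = maxMemory - currentUnrollMemory >= memory
    if (granted) {
      val taskId = currentTaskAttemptId()
      unrollMemoryMap(taskId) = unrollMemoryMap.getOrElse(taskId, 0L) + memory
    }
    granted
  }

  /** Release the current task's reservation; with no argument, drop it entirely. */
  def releaseForThisTask(memory: Long = -1L): Unit = accountingLock.synchronized {
    val taskId = currentTaskAttemptId()
    if (memory < 0) {
      unrollMemoryMap.remove(taskId)
    } else {
      unrollMemoryMap(taskId) = unrollMemoryMap.getOrElse(taskId, memory) - memory
      // If this task claims no more unroll memory, drop its entry completely.
      if (unrollMemoryMap(taskId) <= 0) unrollMemoryMap.remove(taskId)
    }
  }
}
```
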
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index e49e39679e940..a759ceb96ec1e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -21,18 +21,19 @@ import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.{Failure, Try}
+import scala.util.control.NonFatal
-import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.{Logging, SparkException, TaskContext}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
* manager. For remote blocks, it fetches them using the provided BlockTransferService.
*
- * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
+ * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks
* in a pipelined fashion as they are received.
*
* The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
@@ -53,7 +54,7 @@ final class ShuffleBlockFetcherIterator(
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
maxBytesInFlight: Long)
- extends Iterator[(BlockId, Try[InputStream])] with Logging {
+ extends Iterator[(BlockId, InputStream)] with Logging {
import ShuffleBlockFetcherIterator._
@@ -115,7 +116,7 @@ final class ShuffleBlockFetcherIterator(
private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
currentResult match {
- case SuccessFetchResult(_, _, buf) => buf.release()
+ case SuccessFetchResult(_, _, _, buf) => buf.release()
case _ =>
}
currentResult = null
@@ -132,7 +133,7 @@ final class ShuffleBlockFetcherIterator(
while (iter.hasNext) {
val result = iter.next()
result match {
- case SuccessFetchResult(_, _, buf) => buf.release()
+ case SuccessFetchResult(_, _, _, buf) => buf.release()
case _ =>
}
}
@@ -157,7 +158,7 @@ final class ShuffleBlockFetcherIterator(
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
- results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
@@ -166,7 +167,7 @@ final class ShuffleBlockFetcherIterator(
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- results.put(new FailureFetchResult(BlockId(blockId), e))
+ results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
}
)
@@ -238,12 +239,12 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
- results.put(new SuccessFetchResult(blockId, 0, buf))
+ results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FailureFetchResult(blockId, e))
+ results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
return
}
}
@@ -275,12 +276,14 @@ final class ShuffleBlockFetcherIterator(
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
/**
- * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
+ * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
* underlying each InputStream will be freed by the cleanup() method registered with the
* TaskCompletionListener. However, callers should close() these InputStreams
* as soon as they are no longer needed, in order to release memory as early as possible.
+ *
+ * Throws a FetchFailedException if the next block could not be fetched.
*/
- override def next(): (BlockId, Try[InputStream]) = {
+ override def next(): (BlockId, InputStream) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
@@ -289,7 +292,7 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
- case SuccessFetchResult(_, size, _) => bytesInFlight -= size
+ case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size
case _ =>
}
// Send fetch requests up to maxBytesInFlight
@@ -298,19 +301,28 @@ final class ShuffleBlockFetcherIterator(
sendRequest(fetchRequests.dequeue())
}
- val iteratorTry: Try[InputStream] = result match {
- case FailureFetchResult(_, e) =>
- Failure(e)
- case SuccessFetchResult(blockId, _, buf) =>
- // There is a chance that createInputStream can fail (e.g. fetching a local file that does
- // not exist, SPARK-4085). In that case, we should propagate the right exception so
- // the scheduler gets a FetchFailedException.
- Try(buf.createInputStream()).map { inputStream =>
- new BufferReleasingInputStream(inputStream, this)
+ result match {
+ case FailureFetchResult(blockId, address, e) =>
+ throwFetchFailedException(blockId, address, e)
+
+ case SuccessFetchResult(blockId, address, _, buf) =>
+ try {
+ (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
+ } catch {
+ case NonFatal(t) =>
+ throwFetchFailedException(blockId, address, t)
}
}
+ }
- (result.blockId, iteratorTry)
+ private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = {
+ blockId match {
+ case ShuffleBlockId(shufId, mapId, reduceId) =>
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
+ case _ =>
+ throw new SparkException(
+ "Failed to get block " + blockId + ", which is not a shuffle block", e)
+ }
}
}
@@ -366,16 +378,22 @@ object ShuffleBlockFetcherIterator {
*/
private[storage] sealed trait FetchResult {
val blockId: BlockId
+ val address: BlockManagerId
}
/**
* Result of a successful fetch of a remote block.
* @param blockId block id
+ * @param address BlockManager that the block was fetched from.
* @param size estimated size of the block, used to calculate bytesInFlight.
* Note that this is NOT the exact bytes.
* @param buf [[ManagedBuffer]] for the content.
*/
- private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer)
+ private[storage] case class SuccessFetchResult(
+ blockId: BlockId,
+ address: BlockManagerId,
+ size: Long,
+ buf: ManagedBuffer)
extends FetchResult {
require(buf != null)
require(size >= 0)
@@ -384,8 +402,12 @@ object ShuffleBlockFetcherIterator {
/**
* Result of a failed fetch of a remote block.
* @param blockId block id
+ * @param address BlockManager that the block was attempted to be fetched from
* @param e the failure exception
*/
- private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable)
+ private[storage] case class FailureFetchResult(
+ blockId: BlockId,
+ address: BlockManagerId,
+ e: Throwable)
extends FetchResult
}
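
The iterator's element type changes from `(BlockId, Try[InputStream])` to `(BlockId, InputStream)`: instead of handing callers a `Try` to unwrap, failures are converted into exceptions at `next()` time (a `FetchFailedException` for shuffle blocks, so the scheduler can recompute the missing map output). A simplified, hypothetical sketch of that conversion pattern, using plain strings and `RuntimeException` in place of Spark's block and exception types:

```scala
import java.io.InputStream
import scala.util.control.NonFatal

// Simplified stand-ins for Spark's fetch results; blockId/address are plain strings here.
sealed trait FetchResult { def blockId: String; def address: String }
case class SuccessFetchResult(blockId: String, address: String, size: Long,
                              open: () => InputStream) extends FetchResult
case class FailureFetchResult(blockId: String, address: String, e: Throwable) extends FetchResult

// Consumers see only successfully opened streams; anything else surfaces as an exception.
class SimpleFetchIterator(results: Iterator[FetchResult])
  extends Iterator[(String, InputStream)] {

  override def hasNext: Boolean = results.hasNext

  override def next(): (String, InputStream) = results.next() match {
    case FailureFetchResult(blockId, address, e) =>
      // In Spark this becomes a FetchFailedException for shuffle blocks, so the
      // scheduler can recompute the missing map output instead of failing the job.
      throw new RuntimeException(s"Failed to fetch $blockId from $address", e)
    case SuccessFetchResult(blockId, address, _, open) =>
      try {
        // Opening the stream itself can fail (compare SPARK-4085), so wrap that too.
        (blockId, open())
      } catch {
        case NonFatal(t) =>
          throw new RuntimeException(s"Failed to open $blockId from $address", t)
      }
  }
}
```
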
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 06e616220c706..c8356467fab87 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -68,7 +68,9 @@ private[spark] object JettyUtils extends Logging {
response.setStatus(HttpServletResponse.SC_OK)
val result = servletParams.responder(request)
response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ // scalastyle:off println
response.getWriter.println(servletParams.extractFn(result))
+ // scalastyle:on println
} else {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
@@ -210,10 +212,16 @@ private[spark] object JettyUtils extends Logging {
conf: SparkConf,
serverName: String = ""): ServerInfo = {
- val collection = new ContextHandlerCollection
- collection.setHandlers(handlers.toArray)
addFilters(handlers, conf)
+ val collection = new ContextHandlerCollection
+ val gzipHandlers = handlers.map { h =>
+ val gzipHandler = new GzipHandler
+ gzipHandler.setHandler(h)
+ gzipHandler
+ }
+ collection.setHandlers(gzipHandlers.toArray)
+
// Bind to the given port, or throw a java.net.BindException if the port is occupied
def connect(currentPort: Int): (Server, Int) = {
val server = new Server(new InetSocketAddress(hostName, currentPort))
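
The JettyUtils change wraps every registered context handler in a `GzipHandler` before adding it to the `ContextHandlerCollection`, so UI pages and JSON responses are compressed for clients that advertise gzip support. Below is a hedged sketch of the same wrapping pattern; `GzipWrapExample` and `gzipAll` are made-up names, and it assumes the Jetty 8.x `GzipHandler` in `org.eclipse.jetty.server.handler` (newer Jetty releases moved it to `org.eclipse.jetty.server.handler.gzip`):

```scala
import org.eclipse.jetty.server.handler.{ContextHandlerCollection, GzipHandler}
import org.eclipse.jetty.servlet.ServletContextHandler

object GzipWrapExample {
  // Wrap each handler so responses are compressed when the client sends
  // "Accept-Encoding: gzip"; each GzipHandler delegates to the original handler.
  def gzipAll(handlers: Seq[ServletContextHandler]): ContextHandlerCollection = {
    val collection = new ContextHandlerCollection
    val gzipHandlers = handlers.map { h =>
      val gzipHandler = new GzipHandler
      gzipHandler.setHandler(h)
      gzipHandler
    }
    collection.setHandlers(gzipHandlers.toArray) // inferred as Array[Handler]
    collection
  }
}
```
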
diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
new file mode 100644
index 0000000000000..17d7b39c2d951
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.xml.{Node, Unparsed}
+
+/**
+ * A data source that provides data for a page.
+ *
+ * @param pageSize the number of rows in a page
+ */
+private[ui] abstract class PagedDataSource[T](val pageSize: Int) {
+
+ if (pageSize <= 0) {
+ throw new IllegalArgumentException("Page size must be positive")
+ }
+
+ /**
+ * Return the size of all data.
+ */
+ protected def dataSize: Int
+
+ /**
+ * Slice a range of data.
+ */
+ protected def sliceData(from: Int, to: Int): Seq[T]
+
+ /**
+ * Slice the data for this page
+ */
+ def pageData(page: Int): PageData[T] = {
+ val totalPages = (dataSize + pageSize - 1) / pageSize
+ if (page <= 0 || page > totalPages) {
+ throw new IndexOutOfBoundsException(
+ s"Page $page is out of range. Please select a page number between 1 and $totalPages.")
+ }
+ val from = (page - 1) * pageSize
+ val to = dataSize.min(page * pageSize)
+ PageData(totalPages, sliceData(from, to))
+ }
+
+}
+
+/**
+ * The data returned by `PagedDataSource.pageData`, including the total number of pages and the
+ * data in this page.
+ */
+private[ui] case class PageData[T](totalPage: Int, data: Seq[T])
+
+/**
+ * A paged table that will generate an HTML table for a specified page, along with the page navigation.
+ */
+private[ui] trait PagedTable[T] {
+
+ def tableId: String
+
+ def tableCssClass: String
+
+ def dataSource: PagedDataSource[T]
+
+ def headers: Seq[Node]
+
+ def row(t: T): Seq[Node]
+
+ def table(page: Int): Seq[Node] = {
+ val _dataSource = dataSource
+ try {
+ val PageData(totalPages, data) = _dataSource.pageData(page)
+      val pageNavi = pageNavigation(page, _dataSource.pageSize, totalPages)
+      // ... (markup rendering the navigation and the table of headers and rows elided) ...
+    } catch {
+      case e: IndexOutOfBoundsException =>
+        // ... (error-message markup elided) ...
+    }
+  }
+
+ /**
+ * Return a page navigation.
+ *
+ * If the totalPages is 1, the page navigation will be empty.
+ *
+ * If the totalPages is more than 1, it will create a page navigation including a group of
+ * page numbers and a form to submit the page number.
+ *
+ * Here are some examples of the page navigation:
+ * {{{
+ * << < 11 12 13* 14 15 16 17 18 19 20 > >>
+ *
+ * This is the first group, so "<<" is hidden.
+ * < 1 2* 3 4 5 6 7 8 9 10 > >>
+ *
+ * This is the first group and the first page, so "<<" and "<" are hidden.
+ * 1* 2 3 4 5 6 7 8 9 10 > >>
+ *
+ * Assume totalPages is 19. This is the last group, so ">>" is hidden.
+ * << < 11 12 13* 14 15 16 17 18 19 >
+ *
+ * Assume totalPages is 19. This is the last group and the last page, so ">>" and ">" are hidden.
+ * << < 11 12 13 14 15 16 17 18 19*
+ *
+ * * means the current page number
+ * << means jumping to the first page of the previous group.
+ * < means jumping to the previous page.
+ * >> means jumping to the first page of the next group.
+ * > means jumping to the next page.
+ * }}}
+ */
+ private[ui] def pageNavigation(page: Int, pageSize: Int, totalPages: Int): Seq[Node] = {
+ if (totalPages == 1) {
+ Nil
+ } else {
+      // A group includes all page numbers that will be shown in the page navigation.
+      // A group size of 10 means 10 page numbers will be shown.
+      // The first group is 1 to 10, the second is 11 to 20, and so on.
+ val groupSize = 10
+ val firstGroup = 0
+ val lastGroup = (totalPages - 1) / groupSize
+ val currentGroup = (page - 1) / groupSize
+ val startPage = currentGroup * groupSize + 1
+ val endPage = totalPages.min(startPage + groupSize - 1)
+      val pageTags = (startPage to endPage).map { p =>
+        if (p == page) {
+          // The current page should be disabled so that it cannot be clicked.
+          // ... (<li> markup for the disabled current page elided) ...
+        } else {
+          // ... (<li> markup linking to pageLink(p) elided) ...
+        }
+      }
+      // ... (markup combining pageTags with the jump-to-page form and the
+      // previous/next group buttons elided) ...
+    }
+  }
+
+ /**
+ * Return a link to jump to a page.
+ */
+ def pageLink(page: Int): String
+
+ /**
+ * Only the implementation knows how to create the URL with a page number and the page size, so we
+ * leave this to the implementation. The implementation should create a JavaScript method that
+ * accepts a page number along with the page size and jumps to the page. The return value is the
+ * name of this method and its JavaScript code.
+ */
+ def goButtonJavascriptFunction: (String, String)
+}
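
To make the new abstraction concrete, here is a hypothetical data source (not part of this patch) backed by an in-memory sequence; `SeqDataSource` is a made-up name. `dataSize` reports the total row count and `sliceData` returns the half-open range `[from, to)` that `pageData` computes for the requested page:

```scala
package org.apache.spark.ui // PagedDataSource is private[ui]

// Hypothetical example only: page through an in-memory Seq.
private[ui] class SeqDataSource[T](data: Seq[T], pageSize: Int)
  extends PagedDataSource[T](pageSize) {

  override protected def dataSize: Int = data.size

  override protected def sliceData(from: Int, to: Int): Seq[T] = data.slice(from, to)
}

// With 25 rows and pageSize = 10: pageData(3) returns PageData(3, <last 5 rows>),
// while pageData(4) throws IndexOutOfBoundsException, which PagedTable.table is
// expected to surface as an error message rather than a table.
```
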
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 7898039519201..718aea7e1dc22 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph
/** Utility functions for generating XML pages with spark content. */
private[spark] object UIUtils extends Logging {
- val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable"
+ val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed"
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
@@ -267,9 +267,17 @@ private[spark] object UIUtils extends Logging {
fixedWidth: Boolean = false,
id: Option[String] = None,
headerClasses: Seq[String] = Seq.empty,
- stripeRowsWithCss: Boolean = true): Seq[Node] = {
+ stripeRowsWithCss: Boolean = true,
+ sortable: Boolean = true): Seq[Node] = {
- val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED
+ val listingTableClass = {
+ val _tableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED
+ if (sortable) {
+ _tableClass + " sortable"
+ } else {
+ _tableClass
+ }
+ }
val colWidth = 100.toDouble / headers.size
val colWidthAttr = if (fixedWidth) colWidth + "%" else ""
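
The new `sortable` flag simply controls whether the `sortable` CSS class (picked up by the UI's client-side sort script) is appended to the table class, which lets the paginated task table opt out of client-side sorting. A hedged usage sketch follows; `ListingTableExample`, `renderProps`, and the header/row values are illustrative only:

```scala
package org.apache.spark.ui // UIUtils is private[spark]

import scala.xml.Node

private[spark] object ListingTableExample {
  // Render a small non-sortable listing table; sorting for such tables is expected
  // to happen elsewhere (e.g. server-side, as in the new paginated task table).
  def renderProps(props: Seq[(String, String)]): Seq[Node] = {
    def row(kv: (String, String)): Seq[Node] = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
    UIUtils.listingTable(Seq("Name", "Value"), row, props, sortable = false)
  }
}
```
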
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index ba03acdb38cc5..5a8c2914314c2 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -38,9 +38,11 @@ private[spark] object UIWorkloadGenerator {
def main(args: Array[String]) {
if (args.length < 3) {
+ // scalastyle:off println
println(
- "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " +
+ "Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " +
"[master] [FIFO|FAIR] [#job set (4 jobs per set)]")
+ // scalastyle:on println
System.exit(1)
}
@@ -96,6 +98,7 @@ private[spark] object UIWorkloadGenerator {
for ((desc, job) <- jobs) {
new Thread {
override def run() {
+ // scalastyle:off println
try {
setProperties(desc)
job()
@@ -106,6 +109,7 @@ private[spark] object UIWorkloadGenerator {
} finally {
barrier.release()
}
+ // scalastyle:on println
}
}.start
Thread.sleep(INTER_JOB_WAIT_MS)
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 2c84e4485996e..61449847add3d 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -107,6 +107,25 @@ private[spark] abstract class WebUI(
}
}
+ /**
+ * Add a handler for static content.
+ *
+ * @param resourceBase Root of where to find resources to serve.
+ * @param path Path in UI where to mount the resources.
+ */
+ def addStaticHandler(resourceBase: String, path: String): Unit = {
+ attachHandler(JettyUtils.createStaticHandler(resourceBase, path))
+ }
+
+ /**
+ * Remove a static content handler.
+ *
+ * @param path Path in UI to unmount.
+ */
+ def removeStaticHandler(path: String): Unit = {
+ handlers.find(_.getContextPath() == path).foreach(detachHandler)
+ }
+
/** Initialize all components of the server. */
def initialize()
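
A brief usage sketch of the two new helpers; `StaticHandlerExample`, the resource base, and the mount path are made-up values, and `ui` stands for any concrete `WebUI`:

```scala
package org.apache.spark.ui // WebUI is private[spark]

private[spark] object StaticHandlerExample {
  def withCustomStatics(ui: WebUI): Unit = {
    // Serve classpath resources under a (hypothetical) /static/custom path.
    ui.addStaticHandler("org/apache/spark/ui/static", "/static/custom")
    // Later, detach it again; removeStaticHandler looks the handler up by context path.
    ui.removeStaticHandler("/static/custom")
  }
}
```
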
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 2ce670ad02e97..e72547df7254b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -79,6 +79,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
case JobExecutionStatus.SUCCEEDED => "succeeded"
case JobExecutionStatus.FAILED => "failed"
case JobExecutionStatus.RUNNING => "running"
+ case JobExecutionStatus.UNKNOWN => "unknown"
}
// The timeline library treats contents as HTML, so we have to escape them; for the
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 60e3c6343122c..cf04b5e59239b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ui.jobs
+import java.net.URLEncoder
import java.util.Date
import javax.servlet.http.HttpServletRequest
@@ -27,13 +28,14 @@ import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
-import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils}
+import org.apache.spark.ui._
import org.apache.spark.ui.jobs.UIData._
-import org.apache.spark.ui.scope.RDDOperationGraph
import org.apache.spark.util.{Utils, Distribution}
/** Page showing statistics and task list for a given stage */
private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+ import StagePage._
+
private val progressListener = parent.progressListener
private val operationGraphListener = parent.operationGraphListener
@@ -74,6 +76,16 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val parameterAttempt = request.getParameter("attempt")
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")
+ val parameterTaskPage = request.getParameter("task.page")
+ val parameterTaskSortColumn = request.getParameter("task.sort")
+ val parameterTaskSortDesc = request.getParameter("task.desc")
+ val parameterTaskPageSize = request.getParameter("task.pageSize")
+
+ val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
+ val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index")
+ val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false)
+ val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100)
+
// If this is set, expand the dag visualization by default
val expandDagVizParam = request.getParameter("expandDagViz")
val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean
@@ -231,52 +243,47 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
accumulableRow,
accumulables.values.toSeq)
- val taskHeadersAndCssClasses: Seq[(String, String)] =
- Seq(
- ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""),
- ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""),
- ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY),
- ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME),
- ("GC Time", ""),
- ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
- ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++
- {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
- {if (stageData.hasInput) Seq(("Input Size / Records", "")) else Nil} ++
- {if (stageData.hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
- {if (stageData.hasShuffleRead) {
- Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME),
- ("Shuffle Read Size / Records", ""),
- ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE))
- } else {
- Nil
- }} ++
- {if (stageData.hasShuffleWrite) {
- Seq(("Write Time", ""), ("Shuffle Write Size / Records", ""))
- } else {
- Nil
- }} ++
- {if (stageData.hasBytesSpilled) {
- Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", ""))
- } else {
- Nil
- }} ++
- Seq(("Errors", ""))
-
- val unzipped = taskHeadersAndCssClasses.unzip
-
val currentTime = System.currentTimeMillis()
- val taskTable = UIUtils.listingTable(
- unzipped._1,
- taskRow(
+ val (taskTable, taskTableHTML) = try {
+ val _taskTable = new TaskPagedTable(
+ UIUtils.prependBaseUri(parent.basePath) +
+ s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}",
+ tasks,
hasAccumulators,
stageData.hasInput,
stageData.hasOutput,
stageData.hasShuffleRead,
stageData.hasShuffleWrite,
stageData.hasBytesSpilled,
- currentTime),
- tasks,
- headerClasses = unzipped._2)
+ currentTime,
+ pageSize = taskPageSize,
+ sortColumn = taskSortColumn,
+ desc = taskSortDesc
+ )
+ (_taskTable, _taskTable.table(taskPage))
+ } catch {
+ case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
+        (null, <div class="alert alert-error">{e.getMessage}</div>)
+ }
+
+    val jsForScrollingDownToTaskTable =
+      // ... (<script> markup that scrolls the page down to the task table elided) ...
+
+ val taskIdsInPage = if (taskTable == null) Set.empty[Long]
+ else taskTable.dataSource.slicedTaskIds
+
// Excludes tasks which failed and have incomplete metrics
val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined)
@@ -332,7 +339,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+: getFormattedTimeQuantiles(serializationTimes)
val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) =>
- getGettingResultTime(info).toDouble
+ getGettingResultTime(info, currentTime).toDouble
}
val gettingResultQuantiles =
@@ -346,7 +353,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
- getSchedulerDelay(info, metrics.get).toDouble
+ getSchedulerDelay(info, metrics.get, currentTime).toDouble
}
val schedulerDelayTitle =
Scheduler Delay
@@ -499,12 +506,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
dagViz ++
maybeExpandDagViz ++
showAdditionalMetrics ++
- makeTimeline(stageData.taskData.values.toSeq, currentTime) ++
+ makeTimeline(
+ // Only show the tasks in the table
+ stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)),
+ currentTime) ++
Summary Metrics for {numCompleted} Completed Tasks
++
{summaryTable.getOrElse("No tasks have reported metrics yet.")}