diff --git a/.rat-excludes b/.rat-excludes
index 0240e81c45ea2..236c2db05367c 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -91,3 +91,5 @@ help/*
html/*
INDEX
.lintr
+gen-java.*
+.*avpr
diff --git a/LICENSE b/LICENSE
index 8672be55eca3e..f9e412cade345 100644
--- a/LICENSE
+++ b/LICENSE
@@ -948,6 +948,6 @@ The following components are provided under the MIT License. See project link fo
(MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org)
(MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/)
(MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt)
- (The MIT License) Mockito (org.mockito:mockito-core:1.8.5 - http://www.mockito.org)
+ (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org)
(MIT License) jquery (https://jquery.org/license/)
(MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs)
diff --git a/R/README.md b/R/README.md
index d7d65b4f0eca5..005f56da1670c 100644
--- a/R/README.md
+++ b/R/README.md
@@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R
#### Build Spark
-Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run
+Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run
```
build/mvn -DskipTests -Psparkr package
```
diff --git a/R/install-dev.bat b/R/install-dev.bat
index 008a5c668bc45..f32670b67de96 100644
--- a/R/install-dev.bat
+++ b/R/install-dev.bat
@@ -25,3 +25,8 @@ set SPARK_HOME=%~dp0..
MKDIR %SPARK_HOME%\R\lib
R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\
+
+rem Zip the SparkR package so that it can be distributed to worker nodes on YARN
+pushd %SPARK_HOME%\R\lib
+%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR
+popd
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 1edd551f8d243..4972bb9217072 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -34,7 +34,7 @@ LIB_DIR="$FWDIR/lib"
mkdir -p $LIB_DIR
-pushd $FWDIR
+pushd $FWDIR > /dev/null
# Generate Rd files if devtools is installed
Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }'
@@ -42,4 +42,8 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo
# Install SparkR to $LIB_DIR
R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
-popd
+# Zip the SparkR package so that it can be distributed to worker nodes on YARN
+cd $LIB_DIR
+jar cfM "$LIB_DIR/sparkr.zip" SparkR
+
+popd > /dev/null
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index efc85bbc4b316..d028821534b1a 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -32,4 +32,3 @@ Collate:
'serialize.R'
'sparkR.R'
'utils.R'
- 'zzz.R'
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 6feabf4189c2d..60702824acb46 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -169,8 +169,8 @@ setMethod("isLocal",
#'}
setMethod("showDF",
signature(x = "DataFrame"),
- function(x, numRows = 20) {
- s <- callJMethod(x@sdf, "showString", numToInt(numRows))
+ function(x, numRows = 20, truncate = TRUE) {
+ s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate)
cat(s)
})
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 89511141d3ef7..d2d096709245d 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -165,7 +165,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
serializedFuncArr,
rdd@env$prev_serializedMode,
packageNamesArr,
- as.character(.sparkREnv[["libname"]]),
broadcastArr,
callJMethod(prev_jrdd, "classTag"))
} else {
@@ -175,7 +174,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
rdd@env$prev_serializedMode,
serializedMode,
packageNamesArr,
- as.character(.sparkREnv[["libname"]]),
broadcastArr,
callJMethod(prev_jrdd, "classTag"))
}
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 9a743a3411533..30978bb50d339 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -86,7 +86,9 @@ infer_type <- function(x) {
createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) {
if (is.data.frame(data)) {
# get the names of columns, they will be put into RDD
- schema <- names(data)
+ if (is.null(schema)) {
+ schema <- names(data)
+ }
n <- nrow(data)
m <- ncol(data)
# get rid of factor type
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 79055b7f18558..fad9d71158c51 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -20,7 +20,8 @@
# @rdname aggregateRDD
# @seealso reduce
# @export
-setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") })
+setGeneric("aggregateRDD",
+ function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") })
# @rdname cache-methods
# @export
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index 7f902ba8e683e..ebc6ff65e9d0f 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -215,7 +215,6 @@ setMethod("partitionBy",
serializedHashFuncBytes,
getSerializedMode(x),
packageNamesArr,
- as.character(.sparkREnv$libname),
broadcastArr,
callJMethod(jrdd, "classTag"))
@@ -560,8 +559,8 @@ setMethod("join",
# Left outer join two RDDs
#
# @description
-# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V).
-# The key types of the two RDDs should be the same.
+# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of
+# the form list(K, V). The key types of the two RDDs should be the same.
#
# @param x An RDD to be joined. Should be an RDD where each element is
# list(K, V).
@@ -597,8 +596,8 @@ setMethod("leftOuterJoin",
# Right outer join two RDDs
#
# @description
-# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V).
-# The key types of the two RDDs should be the same.
+# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of
+# the form list(K, V). The key types of the two RDDs should be the same.
#
# @param x An RDD to be joined. Should be an RDD where each element is
# list(K, V).
@@ -634,8 +633,8 @@ setMethod("rightOuterJoin",
# Full outer join two RDDs
#
# @description
-# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V).
-# The key types of the two RDDs should be the same.
+# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of
+# the form list(K, V). The key types of the two RDDs should be the same.
#
# @param x An RDD to be joined. Should be an RDD where each element is
# list(K, V).
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 633b869f91784..172335809dec2 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -17,10 +17,6 @@
.sparkREnv <- new.env()
-sparkR.onLoad <- function(libname, pkgname) {
- .sparkREnv$libname <- libname
-}
-
# Utility function that returns TRUE if we have an active connection to the
# backend and FALSE otherwise
connExists <- function(env) {
@@ -80,7 +76,6 @@ sparkR.stop <- function() {
#' @param sparkEnvir Named list of environment variables to set on worker nodes.
#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors.
#' @param sparkJars Character string vector of jar files to pass to the worker nodes.
-#' @param sparkRLibDir The path where R is installed on the worker nodes.
#' @param sparkPackages Character string vector of packages from spark-packages.org
#' @export
#' @examples
@@ -101,15 +96,15 @@ sparkR.init <- function(
sparkEnvir = list(),
sparkExecutorEnv = list(),
sparkJars = "",
- sparkRLibDir = "",
sparkPackages = "") {
if (exists(".sparkRjsc", envir = .sparkREnv)) {
- cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")
+ cat(paste("Re-using existing Spark Context.",
+ "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n"))
return(get(".sparkRjsc", envir = .sparkREnv))
}
- sparkMem <- Sys.getenv("SPARK_MEM", "512m")
+ sparkMem <- Sys.getenv("SPARK_MEM", "1024m")
jars <- suppressWarnings(normalizePath(as.character(sparkJars)))
# Classpath separator is ";" on Windows
@@ -169,10 +164,6 @@ sparkR.init <- function(
sparkHome <- normalizePath(sparkHome)
}
- if (nchar(sparkRLibDir) != 0) {
- .sparkREnv$libname <- sparkRLibDir
- }
-
sparkEnvirMap <- new.env()
for (varname in names(sparkEnvir)) {
sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
@@ -180,14 +171,16 @@ sparkR.init <- function(
sparkExecutorEnvMap <- new.env()
if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) {
- sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
+ sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
+ paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
}
for (varname in names(sparkExecutorEnv)) {
sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]]
}
nonEmptyJars <- Filter(function(x) { x != "" }, jars)
- localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) })
+ localJarPaths <- sapply(nonEmptyJars,
+ function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) })
# Set the start time to identify jobjs
# Seconds resolution is good enough for this purpose, so use ints
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 13cec0f712fb4..ea629a64f7158 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -334,18 +334,21 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY",
"MEMORY_ONLY_SER_2",
"OFF_HEAP")) {
match.arg(newLevel)
+ storageLevelClass <- "org.apache.spark.storage.StorageLevel"
storageLevel <- switch(newLevel,
- "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"),
- "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"),
- "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"),
- "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"),
- "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"),
- "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"),
- "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"),
- "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"),
- "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"),
- "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"),
- "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP"))
+ "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"),
+ "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"),
+ "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"),
+ "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"),
+ "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass,
+ "MEMORY_AND_DISK_SER"),
+ "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass,
+ "MEMORY_AND_DISK_SER_2"),
+ "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"),
+ "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"),
+ "MEMORY_ONLY_SER" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER"),
+ "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"),
+ "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP"))
}
# Utility function for functions where an argument needs to be integer but we want to allow
@@ -545,9 +548,11 @@ mergePartitions <- function(rdd, zip) {
lengthOfKeys <- part[[len - lengthOfValues]]
stopifnot(len == lengthOfKeys + lengthOfValues)
- # For zip operation, check if corresponding partitions of both RDDs have the same number of elements.
+ # For zip operation, check if corresponding partitions
+ # of both RDDs have the same number of elements.
if (zip && lengthOfKeys != lengthOfValues) {
- stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.")
+ stop(paste("Can only zip RDDs with same number of elements",
+ "in each pair of corresponding partitions."))
}
if (lengthOfKeys > 1) {
diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R
index 8fe711b622086..2a8a8213d0849 100644
--- a/R/pkg/inst/profile/general.R
+++ b/R/pkg/inst/profile/general.R
@@ -16,7 +16,7 @@
#
.First <- function() {
- home <- Sys.getenv("SPARK_HOME")
- .libPaths(c(file.path(home, "R", "lib"), .libPaths()))
+ packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR")
+ .libPaths(c(packageDir, .libPaths()))
Sys.setenv(NOAWT=1)
}
diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R
index 4db7266abc8e2..ccaea18ecab2a 100644
--- a/R/pkg/inst/tests/test_binaryFile.R
+++ b/R/pkg/inst/tests/test_binaryFile.R
@@ -82,7 +82,7 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", {
saveAsObjectFile(rdd2, fileName2)
rdd <- objectFile(sc, c(fileName1, fileName2))
- expect_true(count(rdd) == 2)
+ expect_equal(count(rdd), 2)
unlink(fileName1, recursive = TRUE)
unlink(fileName2, recursive = TRUE)
diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R
index a1e354e567be5..3be8c65a6c1a0 100644
--- a/R/pkg/inst/tests/test_binary_function.R
+++ b/R/pkg/inst/tests/test_binary_function.R
@@ -38,13 +38,13 @@ test_that("union on two RDDs", {
union.rdd <- unionRDD(rdd, text.rdd)
actual <- collect(union.rdd)
expect_equal(actual, c(as.list(nums), mockFile))
- expect_true(getSerializedMode(union.rdd) == "byte")
+ expect_equal(getSerializedMode(union.rdd), "byte")
rdd<- map(text.rdd, function(x) {x})
union.rdd <- unionRDD(rdd, text.rdd)
actual <- collect(union.rdd)
expect_equal(actual, as.list(c(mockFile, mockFile)))
- expect_true(getSerializedMode(union.rdd) == "byte")
+ expect_equal(getSerializedMode(union.rdd), "byte")
unlink(fileName)
})
diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R
index 8bc693be20c3c..cc1faeabffe30 100644
--- a/R/pkg/inst/tests/test_includeJAR.R
+++ b/R/pkg/inst/tests/test_includeJAR.R
@@ -18,8 +18,8 @@ context("include an external JAR in SparkContext")
runScript <- function() {
sparkHome <- Sys.getenv("SPARK_HOME")
- jarPath <- paste("--jars",
- shQuote(file.path(sparkHome, "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar")))
+ sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar"
+ jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath)))
scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R")
submitPath <- file.path(sparkHome, "bin/spark-submit")
res <- system2(command = submitPath,
@@ -31,7 +31,7 @@ runScript <- function() {
test_that("sparkJars tag in SparkContext", {
testOutput <- runScript()
helloTest <- testOutput[1]
- expect_true(helloTest == "Hello, Dave")
+ expect_equal(helloTest, "Hello, Dave")
basicFunction <- testOutput[2]
- expect_true(basicFunction == 4L)
+ expect_equal(basicFunction, "4")
})
diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R
index fff028657db37..2552127cc547f 100644
--- a/R/pkg/inst/tests/test_parallelize_collect.R
+++ b/R/pkg/inst/tests/test_parallelize_collect.R
@@ -57,7 +57,7 @@ test_that("parallelize() on simple vectors and lists returns an RDD", {
strListRDD2)
for (rdd in rdds) {
- expect_true(inherits(rdd, "RDD"))
+ expect_is(rdd, "RDD")
expect_true(.hasSlot(rdd, "jrdd")
&& inherits(rdd@jrdd, "jobj")
&& isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD"))
diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R
index 4fe653856756e..b79692873cec3 100644
--- a/R/pkg/inst/tests/test_rdd.R
+++ b/R/pkg/inst/tests/test_rdd.R
@@ -33,9 +33,9 @@ test_that("get number of partitions in RDD", {
})
test_that("first on RDD", {
- expect_true(first(rdd) == 1)
+ expect_equal(first(rdd), 1)
newrdd <- lapply(rdd, function(x) x + 1)
- expect_true(first(newrdd) == 2)
+ expect_equal(first(newrdd), 2)
})
test_that("count and length on RDD", {
@@ -669,13 +669,15 @@ test_that("fullOuterJoin() on pairwise RDDs", {
rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3)))
rdd2 <- parallelize(sc, list(list(1,1), list(2,4)))
actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
- expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL)))
+ expected <- list(list(1, list(2, 1)), list(1, list(3, 1)),
+ list(2, list(NULL, 4)), list(3, list(3, NULL)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1)))
rdd2 <- parallelize(sc, list(list("a",1), list("b",4)))
actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
- expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL)))
+ expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)),
+ list("a", list(3, 1)), list("c", list(1, NULL)))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(expected))
@@ -683,13 +685,15 @@ test_that("fullOuterJoin() on pairwise RDDs", {
rdd2 <- parallelize(sc, list(list(3,3), list(4,4)))
actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
- sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4)))))
+ sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)),
+ list(3, list(NULL, 3)), list(4, list(NULL, 4)))))
rdd1 <- parallelize(sc, list(list("a",1), list("b",2)))
rdd2 <- parallelize(sc, list(list("c",3), list("d",4)))
actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
- sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3)))))
+ sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)),
+ list("d", list(NULL, 4)), list("c", list(NULL, 3)))))
})
test_that("sortByKey() on pairwise RDDs", {
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 6a08f894313c4..b0ea38854304e 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -61,7 +61,7 @@ test_that("infer types", {
expect_equal(infer_type(list(1L, 2L)),
list(type = 'array', elementType = "integer", containsNull = TRUE))
testStruct <- infer_type(list(a = 1L, b = "2"))
- expect_true(class(testStruct) == "structType")
+ expect_equal(class(testStruct), "structType")
checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE)
e <- new.env()
@@ -73,39 +73,39 @@ test_that("infer types", {
test_that("structType and structField", {
testField <- structField("a", "string")
- expect_true(inherits(testField, "structField"))
- expect_true(testField$name() == "a")
+ expect_is(testField, "structField")
+ expect_equal(testField$name(), "a")
expect_true(testField$nullable())
testSchema <- structType(testField, structField("b", "integer"))
- expect_true(inherits(testSchema, "structType"))
- expect_true(inherits(testSchema$fields()[[2]], "structField"))
- expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType")
+ expect_is(testSchema, "structType")
+ expect_is(testSchema$fields()[[2]], "structField")
+ expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType")
})
test_that("create DataFrame from RDD", {
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
df <- createDataFrame(sqlContext, rdd, list("a", "b"))
- expect_true(inherits(df, "DataFrame"))
- expect_true(count(df) == 10)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 10)
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
df <- createDataFrame(sqlContext, rdd)
- expect_true(inherits(df, "DataFrame"))
+ expect_is(df, "DataFrame")
expect_equal(columns(df), c("_1", "_2"))
schema <- structType(structField(x = "a", type = "integer", nullable = TRUE),
structField(x = "b", type = "string", nullable = TRUE))
df <- createDataFrame(sqlContext, rdd, schema)
- expect_true(inherits(df, "DataFrame"))
+ expect_is(df, "DataFrame")
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) })
df <- createDataFrame(sqlContext, rdd)
- expect_true(inherits(df, "DataFrame"))
- expect_true(count(df) == 10)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 10)
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
})
@@ -150,26 +150,26 @@ test_that("convert NAs to null type in DataFrames", {
test_that("toDF", {
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
df <- toDF(rdd, list("a", "b"))
- expect_true(inherits(df, "DataFrame"))
- expect_true(count(df) == 10)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 10)
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
df <- toDF(rdd)
- expect_true(inherits(df, "DataFrame"))
+ expect_is(df, "DataFrame")
expect_equal(columns(df), c("_1", "_2"))
schema <- structType(structField(x = "a", type = "integer", nullable = TRUE),
structField(x = "b", type = "string", nullable = TRUE))
df <- toDF(rdd, schema)
- expect_true(inherits(df, "DataFrame"))
+ expect_is(df, "DataFrame")
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) })
df <- toDF(rdd)
- expect_true(inherits(df, "DataFrame"))
- expect_true(count(df) == 10)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 10)
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
})
@@ -219,21 +219,21 @@ test_that("create DataFrame with different data types", {
test_that("jsonFile() on a local file returns a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
- expect_true(inherits(df, "DataFrame"))
- expect_true(count(df) == 3)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 3)
})
test_that("jsonRDD() on a RDD with json string", {
rdd <- parallelize(sc, mockLines)
- expect_true(count(rdd) == 3)
+ expect_equal(count(rdd), 3)
df <- jsonRDD(sqlContext, rdd)
- expect_true(inherits(df, "DataFrame"))
- expect_true(count(df) == 3)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 3)
rdd2 <- flatMap(rdd, function(x) c(x, x))
df <- jsonRDD(sqlContext, rdd2)
- expect_true(inherits(df, "DataFrame"))
- expect_true(count(df) == 6)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 6)
})
test_that("test cache, uncache and clearCache", {
@@ -248,9 +248,9 @@ test_that("test cache, uncache and clearCache", {
test_that("test tableNames and tables", {
df <- jsonFile(sqlContext, jsonPath)
registerTempTable(df, "table1")
- expect_true(length(tableNames(sqlContext)) == 1)
+ expect_equal(length(tableNames(sqlContext)), 1)
df <- tables(sqlContext)
- expect_true(count(df) == 1)
+ expect_equal(count(df), 1)
dropTempTable(sqlContext, "table1")
})
@@ -258,8 +258,8 @@ test_that("registerTempTable() results in a queryable table and sql() results in
df <- jsonFile(sqlContext, jsonPath)
registerTempTable(df, "table1")
newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'")
- expect_true(inherits(newdf, "DataFrame"))
- expect_true(count(newdf) == 1)
+ expect_is(newdf, "DataFrame")
+ expect_equal(count(newdf), 1)
dropTempTable(sqlContext, "table1")
})
@@ -279,14 +279,14 @@ test_that("insertInto() on a registered table", {
registerTempTable(dfParquet, "table1")
insertInto(dfParquet2, "table1")
- expect_true(count(sql(sqlContext, "select * from table1")) == 5)
- expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Michael")
+ expect_equal(count(sql(sqlContext, "select * from table1")), 5)
+ expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael")
dropTempTable(sqlContext, "table1")
registerTempTable(dfParquet, "table1")
insertInto(dfParquet2, "table1", overwrite = TRUE)
- expect_true(count(sql(sqlContext, "select * from table1")) == 2)
- expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Bob")
+ expect_equal(count(sql(sqlContext, "select * from table1")), 2)
+ expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob")
dropTempTable(sqlContext, "table1")
})
@@ -294,16 +294,16 @@ test_that("table() returns a new DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
registerTempTable(df, "table1")
tabledf <- table(sqlContext, "table1")
- expect_true(inherits(tabledf, "DataFrame"))
- expect_true(count(tabledf) == 3)
+ expect_is(tabledf, "DataFrame")
+ expect_equal(count(tabledf), 3)
dropTempTable(sqlContext, "table1")
})
test_that("toRDD() returns an RRDD", {
df <- jsonFile(sqlContext, jsonPath)
testRDD <- toRDD(df)
- expect_true(inherits(testRDD, "RDD"))
- expect_true(count(testRDD) == 3)
+ expect_is(testRDD, "RDD")
+ expect_equal(count(testRDD), 3)
})
test_that("union on two RDDs created from DataFrames returns an RRDD", {
@@ -311,9 +311,9 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", {
RDD1 <- toRDD(df)
RDD2 <- toRDD(df)
unioned <- unionRDD(RDD1, RDD2)
- expect_true(inherits(unioned, "RDD"))
- expect_true(SparkR:::getSerializedMode(unioned) == "byte")
- expect_true(collect(unioned)[[2]]$name == "Andy")
+ expect_is(unioned, "RDD")
+ expect_equal(SparkR:::getSerializedMode(unioned), "byte")
+ expect_equal(collect(unioned)[[2]]$name, "Andy")
})
test_that("union on mixed serialization types correctly returns a byte RRDD", {
@@ -333,16 +333,16 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", {
dfRDD <- toRDD(df)
unionByte <- unionRDD(rdd, dfRDD)
- expect_true(inherits(unionByte, "RDD"))
- expect_true(SparkR:::getSerializedMode(unionByte) == "byte")
- expect_true(collect(unionByte)[[1]] == 1)
- expect_true(collect(unionByte)[[12]]$name == "Andy")
+ expect_is(unionByte, "RDD")
+ expect_equal(SparkR:::getSerializedMode(unionByte), "byte")
+ expect_equal(collect(unionByte)[[1]], 1)
+ expect_equal(collect(unionByte)[[12]]$name, "Andy")
unionString <- unionRDD(textRDD, dfRDD)
- expect_true(inherits(unionString, "RDD"))
- expect_true(SparkR:::getSerializedMode(unionString) == "byte")
- expect_true(collect(unionString)[[1]] == "Michael")
- expect_true(collect(unionString)[[5]]$name == "Andy")
+ expect_is(unionString, "RDD")
+ expect_equal(SparkR:::getSerializedMode(unionString), "byte")
+ expect_equal(collect(unionString)[[1]], "Michael")
+ expect_equal(collect(unionString)[[5]]$name, "Andy")
})
test_that("objectFile() works with row serialization", {
@@ -352,7 +352,7 @@ test_that("objectFile() works with row serialization", {
saveAsObjectFile(coalesce(dfRDD, 1L), objectPath)
objectIn <- objectFile(sc, objectPath)
- expect_true(inherits(objectIn, "RDD"))
+ expect_is(objectIn, "RDD")
expect_equal(SparkR:::getSerializedMode(objectIn), "byte")
expect_equal(collect(objectIn)[[2]]$age, 30)
})
@@ -363,35 +363,35 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", {
row$newCol <- row$age + 5
row
})
- expect_true(inherits(testRDD, "RDD"))
+ expect_is(testRDD, "RDD")
collected <- collect(testRDD)
- expect_true(collected[[1]]$name == "Michael")
- expect_true(collected[[2]]$newCol == "35")
+ expect_equal(collected[[1]]$name, "Michael")
+ expect_equal(collected[[2]]$newCol, 35)
})
test_that("collect() returns a data.frame", {
df <- jsonFile(sqlContext, jsonPath)
rdf <- collect(df)
expect_true(is.data.frame(rdf))
- expect_true(names(rdf)[1] == "age")
- expect_true(nrow(rdf) == 3)
- expect_true(ncol(rdf) == 2)
+ expect_equal(names(rdf)[1], "age")
+ expect_equal(nrow(rdf), 3)
+ expect_equal(ncol(rdf), 2)
})
test_that("limit() returns DataFrame with the correct number of rows", {
df <- jsonFile(sqlContext, jsonPath)
dfLimited <- limit(df, 2)
- expect_true(inherits(dfLimited, "DataFrame"))
- expect_true(count(dfLimited) == 2)
+ expect_is(dfLimited, "DataFrame")
+ expect_equal(count(dfLimited), 2)
})
test_that("collect() and take() on a DataFrame return the same number of rows and columns", {
df <- jsonFile(sqlContext, jsonPath)
- expect_true(nrow(collect(df)) == nrow(take(df, 10)))
- expect_true(ncol(collect(df)) == ncol(take(df, 10)))
+ expect_equal(nrow(collect(df)), nrow(take(df, 10)))
+ expect_equal(ncol(collect(df)), ncol(take(df, 10)))
})
-test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", {
+test_that("multiple pipeline transformations result in an RDD with the correct values", {
df <- jsonFile(sqlContext, jsonPath)
first <- lapply(df, function(row) {
row$age <- row$age + 5
@@ -401,9 +401,9 @@ test_that("multiple pipeline transformations starting with a DataFrame result in
row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE
row
})
- expect_true(inherits(second, "RDD"))
- expect_true(count(second) == 3)
- expect_true(collect(second)[[2]]$age == 35)
+ expect_is(second, "RDD")
+ expect_equal(count(second), 3)
+ expect_equal(collect(second)[[2]]$age, 35)
expect_true(collect(second)[[2]]$testCol)
expect_false(collect(second)[[3]]$testCol)
})
@@ -430,36 +430,36 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", {
test_that("schema(), dtypes(), columns(), names() return the correct values/format", {
df <- jsonFile(sqlContext, jsonPath)
testSchema <- schema(df)
- expect_true(length(testSchema$fields()) == 2)
- expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType")
- expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string")
- expect_true(testSchema$fields()[[1]]$name() == "age")
+ expect_equal(length(testSchema$fields()), 2)
+ expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType")
+ expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string")
+ expect_equal(testSchema$fields()[[1]]$name(), "age")
testTypes <- dtypes(df)
- expect_true(length(testTypes[[1]]) == 2)
- expect_true(testTypes[[1]][1] == "age")
+ expect_equal(length(testTypes[[1]]), 2)
+ expect_equal(testTypes[[1]][1], "age")
testCols <- columns(df)
- expect_true(length(testCols) == 2)
- expect_true(testCols[2] == "name")
+ expect_equal(length(testCols), 2)
+ expect_equal(testCols[2], "name")
testNames <- names(df)
- expect_true(length(testNames) == 2)
- expect_true(testNames[2] == "name")
+ expect_equal(length(testNames), 2)
+ expect_equal(testNames[2], "name")
})
test_that("head() and first() return the correct data", {
df <- jsonFile(sqlContext, jsonPath)
testHead <- head(df)
- expect_true(nrow(testHead) == 3)
- expect_true(ncol(testHead) == 2)
+ expect_equal(nrow(testHead), 3)
+ expect_equal(ncol(testHead), 2)
testHead2 <- head(df, 2)
- expect_true(nrow(testHead2) == 2)
- expect_true(ncol(testHead2) == 2)
+ expect_equal(nrow(testHead2), 2)
+ expect_equal(ncol(testHead2), 2)
testFirst <- first(df)
- expect_true(nrow(testFirst) == 1)
+ expect_equal(nrow(testFirst), 1)
})
test_that("distinct() on DataFrames", {
@@ -472,15 +472,15 @@ test_that("distinct() on DataFrames", {
df <- jsonFile(sqlContext, jsonPathWithDup)
uniques <- distinct(df)
- expect_true(inherits(uniques, "DataFrame"))
- expect_true(count(uniques) == 3)
+ expect_is(uniques, "DataFrame")
+ expect_equal(count(uniques), 3)
})
test_that("sample on a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
sampled <- sample(df, FALSE, 1.0)
expect_equal(nrow(collect(sampled)), count(df))
- expect_true(inherits(sampled, "DataFrame"))
+ expect_is(sampled, "DataFrame")
sampled2 <- sample(df, FALSE, 0.1)
expect_true(count(sampled2) < 3)
@@ -491,15 +491,15 @@ test_that("sample on a DataFrame", {
test_that("select operators", {
df <- select(jsonFile(sqlContext, jsonPath), "name", "age")
- expect_true(inherits(df$name, "Column"))
- expect_true(inherits(df[[2]], "Column"))
- expect_true(inherits(df[["age"]], "Column"))
+ expect_is(df$name, "Column")
+ expect_is(df[[2]], "Column")
+ expect_is(df[["age"]], "Column")
- expect_true(inherits(df[,1], "DataFrame"))
+ expect_is(df[,1], "DataFrame")
expect_equal(columns(df[,1]), c("name"))
expect_equal(columns(df[,"age"]), c("age"))
df2 <- df[,c("age", "name")]
- expect_true(inherits(df2, "DataFrame"))
+ expect_is(df2, "DataFrame")
expect_equal(columns(df2), c("age", "name"))
df$age2 <- df$age
@@ -518,50 +518,50 @@ test_that("select operators", {
test_that("select with column", {
df <- jsonFile(sqlContext, jsonPath)
df1 <- select(df, "name")
- expect_true(columns(df1) == c("name"))
- expect_true(count(df1) == 3)
+ expect_equal(columns(df1), c("name"))
+ expect_equal(count(df1), 3)
df2 <- select(df, df$age)
- expect_true(columns(df2) == c("age"))
- expect_true(count(df2) == 3)
+ expect_equal(columns(df2), c("age"))
+ expect_equal(count(df2), 3)
})
test_that("selectExpr() on a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
selected <- selectExpr(df, "age * 2")
- expect_true(names(selected) == "(age * 2)")
+ expect_equal(names(selected), "(age * 2)")
expect_equal(collect(selected), collect(select(df, df$age * 2L)))
selected2 <- selectExpr(df, "name as newName", "abs(age) as age")
expect_equal(names(selected2), c("newName", "age"))
- expect_true(count(selected2) == 3)
+ expect_equal(count(selected2), 3)
})
test_that("column calculation", {
df <- jsonFile(sqlContext, jsonPath)
d <- collect(select(df, alias(df$age + 1, "age2")))
- expect_true(names(d) == c("age2"))
+ expect_equal(names(d), c("age2"))
df2 <- select(df, lower(df$name), abs(df$age))
- expect_true(inherits(df2, "DataFrame"))
- expect_true(count(df2) == 3)
+ expect_is(df2, "DataFrame")
+ expect_equal(count(df2), 3)
})
test_that("read.df() from json file", {
df <- read.df(sqlContext, jsonPath, "json")
- expect_true(inherits(df, "DataFrame"))
- expect_true(count(df) == 3)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 3)
# Check if we can apply a user defined schema
schema <- structType(structField("name", type = "string"),
structField("age", type = "double"))
df1 <- read.df(sqlContext, jsonPath, "json", schema)
- expect_true(inherits(df1, "DataFrame"))
+ expect_is(df1, "DataFrame")
expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double")))
# Run the same with loadDF
df2 <- loadDF(sqlContext, jsonPath, "json", schema)
- expect_true(inherits(df2, "DataFrame"))
+ expect_is(df2, "DataFrame")
expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double")))
})
@@ -569,8 +569,8 @@ test_that("write.df() as parquet file", {
df <- read.df(sqlContext, jsonPath, "json")
write.df(df, parquetPath, "parquet", mode="overwrite")
df2 <- read.df(sqlContext, parquetPath, "parquet")
- expect_true(inherits(df2, "DataFrame"))
- expect_true(count(df2) == 3)
+ expect_is(df2, "DataFrame")
+ expect_equal(count(df2), 3)
})
test_that("test HiveContext", {
@@ -580,17 +580,17 @@ test_that("test HiveContext", {
skip("Hive is not build with SparkSQL, skipped")
})
df <- createExternalTable(hiveCtx, "json", jsonPath, "json")
- expect_true(inherits(df, "DataFrame"))
- expect_true(count(df) == 3)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 3)
df2 <- sql(hiveCtx, "select * from json")
- expect_true(inherits(df2, "DataFrame"))
- expect_true(count(df2) == 3)
+ expect_is(df2, "DataFrame")
+ expect_equal(count(df2), 3)
jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp")
saveAsTable(df, "json", "json", "append", path = jsonPath2)
df3 <- sql(hiveCtx, "select * from json")
- expect_true(inherits(df3, "DataFrame"))
- expect_true(count(df3) == 6)
+ expect_is(df3, "DataFrame")
+ expect_equal(count(df3), 6)
})
test_that("column operators", {
@@ -643,65 +643,65 @@ test_that("string operators", {
test_that("group by", {
df <- jsonFile(sqlContext, jsonPath)
df1 <- agg(df, name = "max", age = "sum")
- expect_true(1 == count(df1))
+ expect_equal(1, count(df1))
df1 <- agg(df, age2 = max(df$age))
- expect_true(1 == count(df1))
+ expect_equal(1, count(df1))
expect_equal(columns(df1), c("age2"))
gd <- groupBy(df, "name")
- expect_true(inherits(gd, "GroupedData"))
+ expect_is(gd, "GroupedData")
df2 <- count(gd)
- expect_true(inherits(df2, "DataFrame"))
- expect_true(3 == count(df2))
+ expect_is(df2, "DataFrame")
+ expect_equal(3, count(df2))
# Also test group_by, summarize, mean
gd1 <- group_by(df, "name")
- expect_true(inherits(gd1, "GroupedData"))
+ expect_is(gd1, "GroupedData")
df_summarized <- summarize(gd, mean_age = mean(df$age))
- expect_true(inherits(df_summarized, "DataFrame"))
- expect_true(3 == count(df_summarized))
+ expect_is(df_summarized, "DataFrame")
+ expect_equal(3, count(df_summarized))
df3 <- agg(gd, age = "sum")
- expect_true(inherits(df3, "DataFrame"))
- expect_true(3 == count(df3))
+ expect_is(df3, "DataFrame")
+ expect_equal(3, count(df3))
df3 <- agg(gd, age = sum(df$age))
- expect_true(inherits(df3, "DataFrame"))
- expect_true(3 == count(df3))
+ expect_is(df3, "DataFrame")
+ expect_equal(3, count(df3))
expect_equal(columns(df3), c("name", "age"))
df4 <- sum(gd, "age")
- expect_true(inherits(df4, "DataFrame"))
- expect_true(3 == count(df4))
- expect_true(3 == count(mean(gd, "age")))
- expect_true(3 == count(max(gd, "age")))
+ expect_is(df4, "DataFrame")
+ expect_equal(3, count(df4))
+ expect_equal(3, count(mean(gd, "age")))
+ expect_equal(3, count(max(gd, "age")))
})
test_that("arrange() and orderBy() on a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
sorted <- arrange(df, df$age)
- expect_true(collect(sorted)[1,2] == "Michael")
+ expect_equal(collect(sorted)[1,2], "Michael")
sorted2 <- arrange(df, "name")
- expect_true(collect(sorted2)[2,"age"] == 19)
+ expect_equal(collect(sorted2)[2,"age"], 19)
sorted3 <- orderBy(df, asc(df$age))
expect_true(is.na(first(sorted3)$age))
- expect_true(collect(sorted3)[2, "age"] == 19)
+ expect_equal(collect(sorted3)[2, "age"], 19)
sorted4 <- orderBy(df, desc(df$name))
- expect_true(first(sorted4)$name == "Michael")
- expect_true(collect(sorted4)[3,"name"] == "Andy")
+ expect_equal(first(sorted4)$name, "Michael")
+ expect_equal(collect(sorted4)[3,"name"], "Andy")
})
test_that("filter() on a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
filtered <- filter(df, "age > 20")
- expect_true(count(filtered) == 1)
- expect_true(collect(filtered)$name == "Andy")
+ expect_equal(count(filtered), 1)
+ expect_equal(collect(filtered)$name, "Andy")
filtered2 <- where(df, df$name != "Michael")
- expect_true(count(filtered2) == 2)
- expect_true(collect(filtered2)$age[2] == 19)
+ expect_equal(count(filtered2), 2)
+ expect_equal(collect(filtered2)$age[2], 19)
# test suites for %in%
filtered3 <- filter(df, "age in (19)")
@@ -727,36 +727,43 @@ test_that("join() on a DataFrame", {
joined <- join(df, df2)
expect_equal(names(joined), c("age", "name", "name", "test"))
- expect_true(count(joined) == 12)
+ expect_equal(count(joined), 12)
joined2 <- join(df, df2, df$name == df2$name)
expect_equal(names(joined2), c("age", "name", "name", "test"))
- expect_true(count(joined2) == 3)
+ expect_equal(count(joined2), 3)
joined3 <- join(df, df2, df$name == df2$name, "right_outer")
expect_equal(names(joined3), c("age", "name", "name", "test"))
- expect_true(count(joined3) == 4)
+ expect_equal(count(joined3), 4)
expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2]))
joined4 <- select(join(df, df2, df$name == df2$name, "outer"),
alias(df$age + 5, "newAge"), df$name, df2$test)
expect_equal(names(joined4), c("newAge", "name", "test"))
- expect_true(count(joined4) == 4)
+ expect_equal(count(joined4), 4)
expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24)
})
test_that("toJSON() returns an RDD of the correct values", {
df <- jsonFile(sqlContext, jsonPath)
testRDD <- toJSON(df)
- expect_true(inherits(testRDD, "RDD"))
- expect_true(SparkR:::getSerializedMode(testRDD) == "string")
+ expect_is(testRDD, "RDD")
+ expect_equal(SparkR:::getSerializedMode(testRDD), "string")
expect_equal(collect(testRDD)[[1]], mockLines[1])
})
test_that("showDF()", {
df <- jsonFile(sqlContext, jsonPath)
s <- capture.output(showDF(df))
- expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n")
+ expected <- paste("+----+-------+\n",
+ "| age| name|\n",
+ "+----+-------+\n",
+ "|null|Michael|\n",
+ "| 30| Andy|\n",
+ "| 19| Justin|\n",
+ "+----+-------+\n", sep="")
+ expect_output(s , expected)
})
test_that("isLocal()", {
@@ -775,50 +782,50 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", {
df2 <- read.df(sqlContext, jsonPath2, "json")
unioned <- arrange(unionAll(df, df2), df$age)
- expect_true(inherits(unioned, "DataFrame"))
- expect_true(count(unioned) == 6)
- expect_true(first(unioned)$name == "Michael")
+ expect_is(unioned, "DataFrame")
+ expect_equal(count(unioned), 6)
+ expect_equal(first(unioned)$name, "Michael")
excepted <- arrange(except(df, df2), desc(df$age))
- expect_true(inherits(unioned, "DataFrame"))
- expect_true(count(excepted) == 2)
- expect_true(first(excepted)$name == "Justin")
+ expect_is(unioned, "DataFrame")
+ expect_equal(count(excepted), 2)
+ expect_equal(first(excepted)$name, "Justin")
intersected <- arrange(intersect(df, df2), df$age)
- expect_true(inherits(unioned, "DataFrame"))
- expect_true(count(intersected) == 1)
- expect_true(first(intersected)$name == "Andy")
+ expect_is(unioned, "DataFrame")
+ expect_equal(count(intersected), 1)
+ expect_equal(first(intersected)$name, "Andy")
})
test_that("withColumn() and withColumnRenamed()", {
df <- jsonFile(sqlContext, jsonPath)
newDF <- withColumn(df, "newAge", df$age + 2)
- expect_true(length(columns(newDF)) == 3)
- expect_true(columns(newDF)[3] == "newAge")
- expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32)
+ expect_equal(length(columns(newDF)), 3)
+ expect_equal(columns(newDF)[3], "newAge")
+ expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32)
newDF2 <- withColumnRenamed(df, "age", "newerAge")
- expect_true(length(columns(newDF2)) == 2)
- expect_true(columns(newDF2)[1] == "newerAge")
+ expect_equal(length(columns(newDF2)), 2)
+ expect_equal(columns(newDF2)[1], "newerAge")
})
test_that("mutate() and rename()", {
df <- jsonFile(sqlContext, jsonPath)
newDF <- mutate(df, newAge = df$age + 2)
- expect_true(length(columns(newDF)) == 3)
- expect_true(columns(newDF)[3] == "newAge")
- expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32)
+ expect_equal(length(columns(newDF)), 3)
+ expect_equal(columns(newDF)[3], "newAge")
+ expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32)
newDF2 <- rename(df, newerAge = df$age)
- expect_true(length(columns(newDF2)) == 2)
- expect_true(columns(newDF2)[1] == "newerAge")
+ expect_equal(length(columns(newDF2)), 2)
+ expect_equal(columns(newDF2)[1], "newerAge")
})
test_that("write.df() on DataFrame and works with parquetFile", {
df <- jsonFile(sqlContext, jsonPath)
write.df(df, parquetPath, "parquet", mode="overwrite")
parquetDF <- parquetFile(sqlContext, parquetPath)
- expect_true(inherits(parquetDF, "DataFrame"))
+ expect_is(parquetDF, "DataFrame")
expect_equal(count(df), count(parquetDF))
})
@@ -828,8 +835,8 @@ test_that("parquetFile works with multiple input paths", {
parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet")
write.df(df, parquetPath2, "parquet", mode="overwrite")
parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2)
- expect_true(inherits(parquetDF, "DataFrame"))
- expect_true(count(parquetDF) == count(df)*2)
+ expect_is(parquetDF, "DataFrame")
+ expect_equal(count(parquetDF), count(df)*2)
})
test_that("describe() on a DataFrame", {
@@ -851,58 +858,58 @@ test_that("dropna() on a DataFrame", {
expected <- rows[!is.na(rows$name),]
actual <- collect(dropna(df, cols = "name"))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows[!is.na(rows$age),]
actual <- collect(dropna(df, cols = "age"))
row.names(expected) <- row.names(actual)
# identical on two dataframes does not work here. Don't know why.
# use identical on all columns as a workaround.
- expect_true(identical(expected$age, actual$age))
- expect_true(identical(expected$height, actual$height))
- expect_true(identical(expected$name, actual$name))
+ expect_identical(expected$age, actual$age)
+ expect_identical(expected$height, actual$height)
+ expect_identical(expected$name, actual$name)
expected <- rows[!is.na(rows$age) & !is.na(rows$height),]
actual <- collect(dropna(df, cols = c("age", "height")))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
actual <- collect(dropna(df))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
# drop with how
expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
actual <- collect(dropna(df))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),]
actual <- collect(dropna(df, "all"))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
actual <- collect(dropna(df, "any"))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows[!is.na(rows$age) & !is.na(rows$height),]
actual <- collect(dropna(df, "any", cols = c("age", "height")))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows[!is.na(rows$age) | !is.na(rows$height),]
actual <- collect(dropna(df, "all", cols = c("age", "height")))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
# drop with threshold
expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,]
actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height")))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows[as.integer(!is.na(rows$age)) +
as.integer(!is.na(rows$height)) +
as.integer(!is.na(rows$name)) >= 3,]
actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height")))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
})
test_that("fillna() on a DataFrame", {
@@ -915,22 +922,22 @@ test_that("fillna() on a DataFrame", {
expected$age[is.na(expected$age)] <- 50
expected$height[is.na(expected$height)] <- 50.6
actual <- collect(fillna(df, 50.6))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows
expected$name[is.na(expected$name)] <- "unknown"
actual <- collect(fillna(df, "unknown"))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows
expected$age[is.na(expected$age)] <- 50
actual <- collect(fillna(df, 50.6, "age"))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
expected <- rows
expected$name[is.na(expected$name)] <- "unknown"
actual <- collect(fillna(df, "unknown", c("age", "name")))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
# fill with named list
@@ -939,7 +946,7 @@ test_that("fillna() on a DataFrame", {
expected$height[is.na(expected$height)] <- 50.6
expected$name[is.na(expected$name)] <- "unknown"
actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown")))
- expect_true(identical(expected, actual))
+ expect_identical(expected, actual)
})
unlink(parquetPath)
diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R
index c5eb417b40159..c2c724cdc762f 100644
--- a/R/pkg/inst/tests/test_take.R
+++ b/R/pkg/inst/tests/test_take.R
@@ -59,8 +59,8 @@ test_that("take() gives back the original elements in correct count and order",
expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3)))
expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1)))
- expect_true(length(take(strListRDD, 0)) == 0)
- expect_true(length(take(strVectorRDD, 0)) == 0)
- expect_true(length(take(numListRDD, 0)) == 0)
- expect_true(length(take(numVectorRDD, 0)) == 0)
+ expect_equal(length(take(strListRDD, 0)), 0)
+ expect_equal(length(take(strVectorRDD, 0)), 0)
+ expect_equal(length(take(numListRDD, 0)), 0)
+ expect_equal(length(take(numVectorRDD, 0)), 0)
})
diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R
index 092ad9dc10c2e..58318dfef71ab 100644
--- a/R/pkg/inst/tests/test_textFile.R
+++ b/R/pkg/inst/tests/test_textFile.R
@@ -27,9 +27,9 @@ test_that("textFile() on a local file returns an RDD", {
writeLines(mockFile, fileName)
rdd <- textFile(sc, fileName)
- expect_true(inherits(rdd, "RDD"))
+ expect_is(rdd, "RDD")
expect_true(count(rdd) > 0)
- expect_true(count(rdd) == 2)
+ expect_equal(count(rdd), 2)
unlink(fileName)
})
@@ -133,7 +133,7 @@ test_that("textFile() on multiple paths", {
writeLines("Spark is awesome.", fileName2)
rdd <- textFile(sc, c(fileName1, fileName2))
- expect_true(count(rdd) == 2)
+ expect_equal(count(rdd), 2)
unlink(fileName1)
unlink(fileName2)
diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R
index 15030e6f1d77e..aa0d2a66b9082 100644
--- a/R/pkg/inst/tests/test_utils.R
+++ b/R/pkg/inst/tests/test_utils.R
@@ -45,10 +45,10 @@ test_that("serializeToBytes on RDD", {
writeLines(mockFile, fileName)
text.rdd <- textFile(sc, fileName)
- expect_true(getSerializedMode(text.rdd) == "string")
+ expect_equal(getSerializedMode(text.rdd), "string")
ser.rdd <- serializeToBytes(text.rdd)
expect_equal(collect(ser.rdd), as.list(mockFile))
- expect_true(getSerializedMode(ser.rdd) == "byte")
+ expect_equal(getSerializedMode(ser.rdd), "byte")
unlink(fileName)
})
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index 43c4288912b18..192d3ae091134 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -22,7 +22,7 @@
# - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2)
# - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1).
# - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G)
-# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)
+# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 1G)
# - SPARK_YARN_APP_NAME, The name of your application (Default: Spark)
# - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’)
# - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job.
diff --git a/core/pom.xml b/core/pom.xml
index 565437c4861a4..558cc3fb9f2f3 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -69,16 +69,6 @@
    <dependency>
      <groupId>org.apache.hadoop</groupId>
      <artifactId>hadoop-client</artifactId>
-      <exclusions>
-        <exclusion>
-          <groupId>javax.servlet</groupId>
-          <artifactId>servlet-api</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.codehaus.jackson</groupId>
-          <artifactId>jackson-mapper-asl</artifactId>
-        </exclusion>
-      </exclusions>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
@@ -353,28 +343,28 @@
      <scope>test</scope>
    </dependency>
    <dependency>
-      <groupId>org.mockito</groupId>
-      <artifactId>mockito-core</artifactId>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-core</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
-      <groupId>org.scalacheck</groupId>
-      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-library</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
-      <groupId>junit</groupId>
-      <artifactId>junit</artifactId>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
-      <groupId>org.hamcrest</groupId>
-      <artifactId>hamcrest-core</artifactId>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
-      <groupId>org.hamcrest</groupId>
-      <artifactId>hamcrest-library</artifactId>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
similarity index 91%
rename from core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
rename to core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
index 3f746b886bc9b..0399abc63c235 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
+++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.unsafe;
+package org.apache.spark.serializer;
import java.io.IOException;
import java.io.InputStream;
@@ -24,9 +24,7 @@
import scala.reflect.ClassTag;
-import org.apache.spark.serializer.DeserializationStream;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.PlatformDependent;
/**
@@ -35,7 +33,8 @@
* `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
* around this, we pass a dummy no-op serializer.
*/
-final class DummySerializerInstance extends SerializerInstance {
+@Private
+public final class DummySerializerInstance extends SerializerInstance {
public static final DummySerializerInstance INSTANCE = new DummySerializerInstance();
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 9e9ed94b7890c..56289573209fb 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -30,6 +30,7 @@
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java
new file mode 100644
index 0000000000000..45b78829e4cf7
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific
+ * comparisons, such as lexicographic comparison for strings.
+ */
+@Private
+public abstract class PrefixComparator {
+ public abstract int compare(long prefix1, long prefix2);
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
new file mode 100644
index 0000000000000..438742565c51d
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import com.google.common.base.Charsets;
+import com.google.common.primitives.Longs;
+import com.google.common.primitives.UnsignedBytes;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.unsafe.types.UTF8String;
+
+@Private
+public class PrefixComparators {
+ private PrefixComparators() {}
+
+ public static final StringPrefixComparator STRING = new StringPrefixComparator();
+ public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator();
+ public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator();
+ public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
+
+ public static final class StringPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+      // TODO: this can be done more efficiently
+ byte[] a = Longs.toByteArray(aPrefix);
+ byte[] b = Longs.toByteArray(bPrefix);
+ for (int i = 0; i < 8; i++) {
+ int c = UnsignedBytes.compare(a[i], b[i]);
+ if (c != 0) return c;
+ }
+ return 0;
+ }
+
+ public long computePrefix(byte[] bytes) {
+ if (bytes == null) {
+ return 0L;
+ } else {
+ byte[] padded = new byte[8];
+ System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8));
+ return Longs.fromByteArray(padded);
+ }
+ }
+
+ public long computePrefix(String value) {
+ return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8));
+ }
+
+ public long computePrefix(UTF8String value) {
+ return value == null ? 0L : computePrefix(value.getBytes());
+ }
+ }
+
+ /**
+ * Prefix comparator for all integral types (boolean, byte, short, int, long).
+ */
+ public static final class IntegralPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long a, long b) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+
+ public final long NULL_PREFIX = Long.MIN_VALUE;
+ }
+
+ public static final class FloatPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ float a = Float.intBitsToFloat((int) aPrefix);
+ float b = Float.intBitsToFloat((int) bPrefix);
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+
+ public long computePrefix(float value) {
+ return Float.floatToIntBits(value) & 0xffffffffL;
+ }
+
+ public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY);
+ }
+
+ public static final class DoublePrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long aPrefix, long bPrefix) {
+ double a = Double.longBitsToDouble(aPrefix);
+ double b = Double.longBitsToDouble(bPrefix);
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+
+ public long computePrefix(double value) {
+ return Double.doubleToLongBits(value);
+ }
+
+ public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY);
+ }
+}
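
A small usage sketch (assumed, not taken from this patch) showing how the string comparator packs the first eight UTF-8 bytes of each value into a long and compares them as unsigned bytes; values that agree on their first eight bytes still need a full record comparison to break the tie:

```java
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators;

public class StringPrefixExample {
  public static void main(String[] args) {
    long a = PrefixComparators.STRING.computePrefix("apache");        // padded with zero bytes
    long b = PrefixComparators.STRING.computePrefix("apache spark");  // truncated to 8 bytes
    // The zero padding of the shorter value makes it compare first.
    System.out.println(PrefixComparators.STRING.compare(a, b) < 0);   // prints true
  }
}
```
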
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
new file mode 100644
index 0000000000000..09e4258792204
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+/**
+ * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte
+ * prefix, this may simply return 0.
+ */
+public abstract class RecordComparator {
+
+ /**
+ * Compare two records for order.
+ *
+ * @return a negative integer, zero, or a positive integer as the first record is less than,
+ * equal to, or greater than the second.
+ */
+ public abstract int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset);
+}
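
As a hedged sketch of what a concrete comparator might look like, assume each record's payload is a single 4-byte integer stored at the given offset; the class below is illustrative and not part of the patch:

```java
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;

// Illustrative comparator: reads one int per record through the unsafe accessor used
// elsewhere in this patch and compares the two values numerically.
public final class IntRecordComparator extends RecordComparator {
  @Override
  public int compare(
      Object leftBaseObject, long leftBaseOffset,
      Object rightBaseObject, long rightBaseOffset) {
    final int left = PlatformDependent.UNSAFE.getInt(leftBaseObject, leftBaseOffset);
    final int right = PlatformDependent.UNSAFE.getInt(rightBaseObject, rightBaseOffset);
    return Integer.compare(left, right);
  }
}
```
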
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
new file mode 100644
index 0000000000000..0c4ebde407cfc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+final class RecordPointerAndKeyPrefix {
+ /**
+ * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
+ * description of how these addresses are encoded.
+ */
+ public long recordPointer;
+
+ /**
+ * A key prefix, for use in comparisons.
+ */
+ public long keyPrefix;
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
new file mode 100644
index 0000000000000..4d6731ee60af3
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+import java.util.LinkedList;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * External sorter based on {@link UnsafeInMemorySorter}.
+ */
+public final class UnsafeExternalSorter {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
+
+ private static final int PAGE_SIZE = 1 << 27; // 128 megabytes
+ @VisibleForTesting
+ static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
+
+ private final PrefixComparator prefixComparator;
+ private final RecordComparator recordComparator;
+ private final int initialSize;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final BlockManager blockManager;
+ private final TaskContext taskContext;
+ private ShuffleWriteMetrics writeMetrics;
+
+ /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+ private final int fileBufferSizeBytes;
+
+ /**
+ * Memory pages that hold the records being sorted. The pages in this list are freed when
+ * spilling, although in principle we could recycle these pages across spills (on the other hand,
+ * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+ * itself).
+ */
+  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+ // These variables are reset after spilling:
+ private UnsafeInMemorySorter sorter;
+ private MemoryBlock currentPage = null;
+ private long currentPagePosition = -1;
+ private long freeSpaceInCurrentPage = 0;
+
+  private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+
+ public UnsafeExternalSorter(
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ int initialSize,
+ SparkConf conf) throws IOException {
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.blockManager = blockManager;
+ this.taskContext = taskContext;
+ this.recordComparator = recordComparator;
+ this.prefixComparator = prefixComparator;
+ this.initialSize = initialSize;
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
+ this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ initializeForWriting();
+ }
+
+ // TODO: metrics tracking + integration with shuffle write metrics
+ // need to connect the write metrics to task metrics so we count the spill IO somewhere.
+
+ /**
+ * Allocates new sort data structures. Called when creating the sorter and after each spill.
+ */
+ private void initializeForWriting() throws IOException {
+ this.writeMetrics = new ShuffleWriteMetrics();
+ // TODO: move this sizing calculation logic into a static method of sorter:
+ final long memoryRequested = initialSize * 8L * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryAcquired != memoryRequested) {
+ shuffleMemoryManager.release(memoryAcquired);
+ throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ }
+
+ this.sorter =
+ new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize);
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
+ @VisibleForTesting
+ public void spill() throws IOException {
+ logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spillWriters.size(),
+      spillWriters.size() > 1 ? "times" : "time");
+
+ final UnsafeSorterSpillWriter spillWriter =
+ new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+ sorter.numRecords());
+ spillWriters.add(spillWriter);
+ final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final Object baseObject = sortedRecords.getBaseObject();
+ final long baseOffset = sortedRecords.getBaseOffset();
+ final int recordLength = sortedRecords.getRecordLength();
+ spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+ }
+ spillWriter.close();
+ final long sorterMemoryUsage = sorter.getMemoryUsage();
+ sorter = null;
+ shuffleMemoryManager.release(sorterMemoryUsage);
+ final long spillSize = freeMemory();
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+ initializeForWriting();
+ }
+
+ private long getMemoryUsage() {
+ return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+ }
+
+ public long freeMemory() {
+ long memoryFreed = 0;
+ for (MemoryBlock block : allocatedPages) {
+ memoryManager.freePage(block);
+ shuffleMemoryManager.release(block.size());
+ memoryFreed += block.size();
+ }
+ allocatedPages.clear();
+ currentPage = null;
+ currentPagePosition = -1;
+ freeSpaceInCurrentPage = 0;
+ return memoryFreed;
+ }
+
+ /**
+ * Checks whether there is enough space to insert a new record into the sorter.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+   *
+ * @return true if the record can be inserted without requiring more allocations, false otherwise.
+ */
+ private boolean haveSpaceForRecord(int requiredSpace) {
+ assert (requiredSpace > 0);
+ return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+ }
+
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ */
+ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+ // TODO: merge these steps to first calculate total memory requirements for this insert,
+ // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
+ // data page.
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ logger.debug("Attempting to expand sort pointer array");
+ final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+ if (memoryAcquired < memoryToGrowPointerArray) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ } else {
+ sorter.expandPointerArray();
+ shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+ }
+ }
+
+ if (requiredSpace > freeSpaceInCurrentPage) {
+      logger.trace("Required space {} exceeds free space in current page ({})", requiredSpace,
+        freeSpaceInCurrentPage);
+ // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+ // without using the free space at the end of the current page. We should also do this for
+ // BytesToBytesMap.
+ if (requiredSpace > PAGE_SIZE) {
+ throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+ PAGE_SIZE + ")");
+ } else {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquired < PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+ }
+ }
+ currentPage = memoryManager.allocatePage(PAGE_SIZE);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = PAGE_SIZE;
+ allocatedPages.add(currentPage);
+ }
+ }
+ }
+
+ /**
+ * Write a record to the sorter.
+ */
+ public void insertRecord(
+ Object recordBaseObject,
+ long recordBaseOffset,
+ int lengthInBytes,
+ long prefix) throws IOException {
+ // Need 4 bytes to store the record length.
+ final int totalSpaceRequired = lengthInBytes + 4;
+ if (!haveSpaceForRecord(totalSpaceRequired)) {
+ allocateSpaceForRecord(totalSpaceRequired);
+ }
+
+ final long recordAddress =
+ memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+ final Object dataPageBaseObject = currentPage.getBaseObject();
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
+ currentPagePosition += 4;
+ PlatformDependent.copyMemory(
+ recordBaseObject,
+ recordBaseOffset,
+ dataPageBaseObject,
+ currentPagePosition,
+ lengthInBytes);
+ currentPagePosition += lengthInBytes;
+
+ sorter.insertRecord(recordAddress, prefix);
+ }
+
+ public UnsafeSorterIterator getSortedIterator() throws IOException {
+ final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
+ int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
+ if (spillWriters.isEmpty()) {
+ return inMemoryIterator;
+ } else {
+ final UnsafeSorterSpillMerger spillMerger =
+ new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
+ for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+ spillMerger.addSpill(spillWriter.getReader(blockManager));
+ }
+ spillWriters.clear();
+ if (inMemoryIterator.hasNext()) {
+ spillMerger.addSpill(inMemoryIterator);
+ }
+ return spillMerger.getSortedIterator();
+ }
+ }
+}
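
A usage sketch under stated assumptions: the memory/block managers and task context come from the running task, the comparators are supplied by the caller, and the record inserted here is a small byte array. This only illustrates the insert/iterate flow, not a definitive integration:

```java
import java.io.IOException;

import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.util.collection.unsafe.sort.*;

public class UnsafeExternalSorterSketch {
  static void sortSomething(
      TaskMemoryManager taskMemoryManager,
      ShuffleMemoryManager shuffleMemoryManager,
      BlockManager blockManager,
      TaskContext taskContext,
      RecordComparator recordComparator,
      PrefixComparator prefixComparator,
      SparkConf conf) throws IOException {
    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
        taskMemoryManager, shuffleMemoryManager, blockManager, taskContext,
        recordComparator, prefixComparator, 4096 /* initial pointer-array slots */, conf);

    // The payload is copied into the current data page behind a 4-byte length header,
    // and a (pointer, prefix) pair is appended to the in-memory sorter.
    final byte[] record = new byte[]{1, 2, 3, 4};
    sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 42L);

    // Iterates in sorted order; if spills occurred, the spill files are merged transparently.
    final UnsafeSorterIterator it = sorter.getSortedIterator();
    while (it.hasNext()) {
      it.loadNext();
      final int length = it.getRecordLength();  // payload length, excluding the header
    }
    sorter.freeMemory();                         // release the sorter's data pages
  }
}
```
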
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
new file mode 100644
index 0000000000000..fc34ad9cff369
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.util.Comparator;
+
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+/**
+ * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
+ * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm
+ * compares records, it will first compare the stored key prefixes; if the prefixes are not equal,
+ * then we do not need to traverse the record pointers to compare the actual records. Avoiding these
+ * random memory accesses improves cache hit rates.
+ */
+public final class UnsafeInMemorySorter {
+
+  private static final class SortComparator implements Comparator<RecordPointerAndKeyPrefix> {
+
+ private final RecordComparator recordComparator;
+ private final PrefixComparator prefixComparator;
+ private final TaskMemoryManager memoryManager;
+
+ SortComparator(
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ TaskMemoryManager memoryManager) {
+ this.recordComparator = recordComparator;
+ this.prefixComparator = prefixComparator;
+ this.memoryManager = memoryManager;
+ }
+
+ @Override
+ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
+ final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix);
+ if (prefixComparisonResult == 0) {
+ final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
+ final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length
+ final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
+ final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length
+ return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ }
+
+ private final TaskMemoryManager memoryManager;
+  private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
+  private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
+
+ /**
+   * Within this buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+ private long[] pointerArray;
+
+ /**
+ * The position in the sort buffer where new records can be inserted.
+ */
+ private int pointerArrayInsertPosition = 0;
+
+ public UnsafeInMemorySorter(
+ final TaskMemoryManager memoryManager,
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ int initialSize) {
+ assert (initialSize > 0);
+ this.pointerArray = new long[initialSize * 2];
+ this.memoryManager = memoryManager;
+ this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
+ this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ }
+
+ /**
+ * @return the number of records that have been inserted into this sorter.
+ */
+ public int numRecords() {
+ return pointerArrayInsertPosition / 2;
+ }
+
+ public long getMemoryUsage() {
+ return pointerArray.length * 8L;
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
+ return pointerArrayInsertPosition + 2 < pointerArray.length;
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = pointerArray;
+ // Guard against overflow:
+ final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+ pointerArray = new long[newLength];
+ System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ }
+
+ /**
+ * Inserts a record to be sorted. Assumes that the record pointer points to a record length
+ * stored as a 4-byte integer, followed by the record's bytes.
+ *
+ * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
+ * @param keyPrefix a user-defined key prefix
+ */
+ public void insertRecord(long recordPointer, long keyPrefix) {
+ if (!hasSpaceForAnotherRecord()) {
+ expandPointerArray();
+ }
+ pointerArray[pointerArrayInsertPosition] = recordPointer;
+ pointerArrayInsertPosition++;
+ pointerArray[pointerArrayInsertPosition] = keyPrefix;
+ pointerArrayInsertPosition++;
+ }
+
+ private static final class SortedIterator extends UnsafeSorterIterator {
+
+ private final TaskMemoryManager memoryManager;
+ private final int sortBufferInsertPosition;
+ private final long[] sortBuffer;
+ private int position = 0;
+ private Object baseObject;
+ private long baseOffset;
+ private long keyPrefix;
+ private int recordLength;
+
+ SortedIterator(
+ TaskMemoryManager memoryManager,
+ int sortBufferInsertPosition,
+ long[] sortBuffer) {
+ this.memoryManager = memoryManager;
+ this.sortBufferInsertPosition = sortBufferInsertPosition;
+ this.sortBuffer = sortBuffer;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return position < sortBufferInsertPosition;
+ }
+
+ @Override
+ public void loadNext() {
+ // This pointer points to a 4-byte record length, followed by the record's bytes
+ final long recordPointer = sortBuffer[position];
+ baseObject = memoryManager.getPage(recordPointer);
+ baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
+ recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4);
+ keyPrefix = sortBuffer[position + 1];
+ position += 2;
+ }
+
+ @Override
+ public Object getBaseObject() { return baseObject; }
+
+ @Override
+ public long getBaseOffset() { return baseOffset; }
+
+ @Override
+ public int getRecordLength() { return recordLength; }
+
+ @Override
+ public long getKeyPrefix() { return keyPrefix; }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order. For efficiency, all calls to
+ * {@code next()} will return the same mutable object.
+ */
+ public UnsafeSorterIterator getSortedIterator() {
+ sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
+ return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
+ }
+}
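
To make the AlphaSort-style idea above concrete, here is a plain-Java sketch (not the patch's code path) that sorts indices of a String array by first comparing precomputed 8-byte prefixes and only touching the full strings when the prefixes tie:

```java
import java.util.Arrays;
import java.util.Comparator;

import org.apache.spark.util.collection.unsafe.sort.PrefixComparators;

public class KeyPrefixSortSketch {
  public static void main(String[] args) {
    final String[] records = {"banana", "apple", "apricot"};
    final long[] prefixes = new long[records.length];
    for (int i = 0; i < records.length; i++) {
      prefixes[i] = PrefixComparators.STRING.computePrefix(records[i]);
    }
    final Integer[] order = {0, 1, 2};
    Arrays.sort(order, new Comparator<Integer>() {
      @Override
      public int compare(Integer a, Integer b) {
        final int byPrefix = PrefixComparators.STRING.compare(prefixes[a], prefixes[b]);
        // Dereference the records only when the cheap prefix comparison cannot decide.
        return byPrefix != 0 ? byPrefix : records[a].compareTo(records[b]);
      }
    });
    System.out.println(Arrays.toString(order));  // [1, 2, 0]: apple, apricot, banana
  }
}
```
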
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
new file mode 100644
index 0000000000000..d09c728a7a638
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+/**
+ * Supports sorting an array of (record pointer, key prefix) pairs.
+ * Used in {@link UnsafeInMemorySorter}.
+ *
+ * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
+
+ public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
+
+ private UnsafeSortDataFormat() { }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+ // Since we re-use keys, this method shouldn't be called.
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix newKey() {
+ return new RecordPointerAndKeyPrefix();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
+ reuse.recordPointer = data[pos * 2];
+ reuse.keyPrefix = data[pos * 2 + 1];
+ return reuse;
+ }
+
+ @Override
+ public void swap(long[] data, int pos0, int pos1) {
+ long tempPointer = data[pos0 * 2];
+ long tempKeyPrefix = data[pos0 * 2 + 1];
+ data[pos0 * 2] = data[pos1 * 2];
+ data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
+ data[pos1 * 2] = tempPointer;
+ data[pos1 * 2 + 1] = tempKeyPrefix;
+ }
+
+ @Override
+ public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+ dst[dstPos * 2] = src[srcPos * 2];
+ dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+ }
+
+ @Override
+ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+ System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+ }
+
+ @Override
+ public long[] allocate(int length) {
+ assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
+ return new long[length * 2];
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
similarity index 65%
rename from core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
rename to core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
index 8df4f3b554c41..16ac2e8d821ba 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -15,17 +15,21 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster.mesos
+package org.apache.spark.util.collection.unsafe.sort;
-import org.apache.spark.SparkContext
+import java.io.IOException;
-private[spark] object MemoryUtils {
- // These defaults copied from YARN
- val OVERHEAD_FRACTION = 0.10
- val OVERHEAD_MINIMUM = 384
+public abstract class UnsafeSorterIterator {
- def calculateTotalMemory(sc: SparkContext): Int = {
- sc.conf.getInt("spark.mesos.executor.memoryOverhead",
- math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory
- }
+ public abstract boolean hasNext();
+
+ public abstract void loadNext() throws IOException;
+
+ public abstract Object getBaseObject();
+
+ public abstract long getBaseOffset();
+
+ public abstract int getRecordLength();
+
+ public abstract long getKeyPrefix();
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
new file mode 100644
index 0000000000000..8272c2a5be0d1
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.PriorityQueue;
+
+final class UnsafeSorterSpillMerger {
+
+  private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
+
+ public UnsafeSorterSpillMerger(
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ final int numSpills) {
+    final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() {
+
+ @Override
+ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
+ final int prefixComparisonResult =
+ prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
+ if (prefixComparisonResult == 0) {
+ return recordComparator.compare(
+ left.getBaseObject(), left.getBaseOffset(),
+ right.getBaseObject(), right.getBaseOffset());
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ };
+    priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
+ }
+
+ public void addSpill(UnsafeSorterIterator spillReader) throws IOException {
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ }
+ priorityQueue.add(spillReader);
+ }
+
+ public UnsafeSorterIterator getSortedIterator() throws IOException {
+ return new UnsafeSorterIterator() {
+
+ private UnsafeSorterIterator spillReader;
+
+ @Override
+ public boolean hasNext() {
+ return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ if (spillReader != null) {
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ priorityQueue.add(spillReader);
+ }
+ }
+ spillReader = priorityQueue.remove();
+ }
+
+ @Override
+ public Object getBaseObject() { return spillReader.getBaseObject(); }
+
+ @Override
+ public long getBaseOffset() { return spillReader.getBaseOffset(); }
+
+ @Override
+ public int getRecordLength() { return spillReader.getRecordLength(); }
+
+ @Override
+ public long getKeyPrefix() { return spillReader.getKeyPrefix(); }
+ };
+ }
+}
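
The merger above is a k-way merge over a priority queue of iterators. A self-contained plain-Java sketch of the same pattern (using boxed ints instead of spill readers, so it is an illustration rather than the patch's code): each source is advanced once before entering the queue, and after a value is consumed its source is re-inserted if it still has elements.

```java
import java.util.*;

public class KWayMergeSketch {
  public static List<Integer> merge(List<Iterator<Integer>> sources) {
    // Queue entries are ordered by their current head value.
    PriorityQueue<PeekingSource> queue = new PriorityQueue<PeekingSource>();
    for (Iterator<Integer> it : sources) {
      if (it.hasNext()) queue.add(new PeekingSource(it));
    }
    List<Integer> out = new ArrayList<Integer>();
    while (!queue.isEmpty()) {
      PeekingSource s = queue.remove();
      out.add(s.head);
      if (s.advance()) queue.add(s);  // same re-insert step UnsafeSorterSpillMerger performs
    }
    return out;
  }

  private static final class PeekingSource implements Comparable<PeekingSource> {
    final Iterator<Integer> it;
    int head;
    PeekingSource(Iterator<Integer> it) { this.it = it; this.head = it.next(); }
    boolean advance() { if (!it.hasNext()) return false; head = it.next(); return true; }
    @Override public int compareTo(PeekingSource other) { return Integer.compare(head, other.head); }
  }

  public static void main(String[] args) {
    List<Integer> merged = merge(Arrays.asList(
        Arrays.asList(1, 4, 7).iterator(),
        Arrays.asList(2, 5, 8).iterator(),
        Arrays.asList(3, 6, 9).iterator()));
    System.out.println(merged);  // [1, 2, 3, 4, 5, 6, 7, 8, 9]
  }
}
```
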
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
new file mode 100644
index 0000000000000..29e9e0f30f934
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.*;
+
+import com.google.common.io.ByteStreams;
+
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
+ * of the file format).
+ */
+final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+
+ private InputStream in;
+ private DataInputStream din;
+
+ // Variables that change with every record read:
+ private int recordLength;
+ private long keyPrefix;
+ private int numRecordsRemaining;
+
+ private byte[] arr = new byte[1024 * 1024];
+ private Object baseObject = arr;
+ private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+
+ public UnsafeSorterSpillReader(
+ BlockManager blockManager,
+ File file,
+ BlockId blockId) throws IOException {
+ assert (file.length() > 0);
+ final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
+ this.in = blockManager.wrapForCompression(blockId, bs);
+ this.din = new DataInputStream(this.in);
+ numRecordsRemaining = din.readInt();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (numRecordsRemaining > 0);
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ recordLength = din.readInt();
+ keyPrefix = din.readLong();
+ if (recordLength > arr.length) {
+ arr = new byte[recordLength];
+ baseObject = arr;
+ }
+ ByteStreams.readFully(in, arr, 0, recordLength);
+ numRecordsRemaining--;
+ if (numRecordsRemaining == 0) {
+ in.close();
+ in = null;
+ din = null;
+ }
+ }
+
+ @Override
+ public Object getBaseObject() {
+ return baseObject;
+ }
+
+ @Override
+ public long getBaseOffset() {
+ return baseOffset;
+ }
+
+ @Override
+ public int getRecordLength() {
+ return recordLength;
+ }
+
+ @Override
+ public long getKeyPrefix() {
+ return keyPrefix;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
new file mode 100644
index 0000000000000..b8d66659804ad
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Tuple2;
+
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.TempLocalBlockId;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Spills a list of sorted records to disk. Spill files have the following format:
+ *
+ * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
+ */
+final class UnsafeSorterSpillWriter {
+
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array.
+ private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ private final File file;
+ private final BlockId blockId;
+ private final int numRecordsToWrite;
+ private BlockObjectWriter writer;
+ private int numRecordsSpilled = 0;
+
+ public UnsafeSorterSpillWriter(
+ BlockManager blockManager,
+ int fileBufferSize,
+ ShuffleWriteMetrics writeMetrics,
+ int numRecordsToWrite) throws IOException {
+    final Tuple2<TempLocalBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempLocalBlock();
+ this.file = spilledFileInfo._2();
+ this.blockId = spilledFileInfo._1();
+ this.numRecordsToWrite = numRecordsToWrite;
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ writer = blockManager.getDiskWriter(
+ blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics);
+ // Write the number of records
+ writeIntToBuffer(numRecordsToWrite, 0);
+ writer.write(writeBuffer, 0, 4);
+ }
+
+ // Based on DataOutputStream.writeLong.
+ private void writeLongToBuffer(long v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 56);
+ writeBuffer[offset + 1] = (byte)(v >>> 48);
+ writeBuffer[offset + 2] = (byte)(v >>> 40);
+ writeBuffer[offset + 3] = (byte)(v >>> 32);
+ writeBuffer[offset + 4] = (byte)(v >>> 24);
+ writeBuffer[offset + 5] = (byte)(v >>> 16);
+ writeBuffer[offset + 6] = (byte)(v >>> 8);
+ writeBuffer[offset + 7] = (byte)(v >>> 0);
+ }
+
+ // Based on DataOutputStream.writeInt.
+ private void writeIntToBuffer(int v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 24);
+ writeBuffer[offset + 1] = (byte)(v >>> 16);
+ writeBuffer[offset + 2] = (byte)(v >>> 8);
+ writeBuffer[offset + 3] = (byte)(v >>> 0);
+ }
+
+ /**
+ * Write a record to a spill file.
+ *
+ * @param baseObject the base object / memory page containing the record
+ * @param baseOffset the base offset which points directly to the record data.
+ * @param recordLength the length of the record.
+ * @param keyPrefix a sort key prefix
+ */
+ public void write(
+ Object baseObject,
+ long baseOffset,
+ int recordLength,
+ long keyPrefix) throws IOException {
+ if (numRecordsSpilled == numRecordsToWrite) {
+ throw new IllegalStateException(
+ "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite);
+ } else {
+ numRecordsSpilled++;
+ }
+ writeIntToBuffer(recordLength, 0);
+ writeLongToBuffer(keyPrefix, 4);
+ int dataRemaining = recordLength;
+ int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len
+ long recordReadPosition = baseOffset;
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining);
+ PlatformDependent.copyMemory(
+ baseObject,
+ recordReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
+ toTransfer);
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE;
+ }
+ if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) {
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer));
+ }
+ writer.recordWritten();
+ }
+
+ public void close() throws IOException {
+ writer.commitAndClose();
+ writer = null;
+ writeBuffer = null;
+ }
+
+ public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
+ return new UnsafeSorterSpillReader(blockManager, file, blockId);
+ }
+}
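
A round-trip sketch of the on-disk record layout using plain java.io streams. The real writer goes through DiskBlockObjectWriter (and the reader may wrap the stream for compression), so this only illustrates the `[# of records][len][prefix][data]` framing, not the actual I/O path:

```java
import java.io.*;

public class SpillFormatSketch {
  public static void main(String[] args) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    byte[] payload = {10, 20, 30};
    out.writeInt(1);               // number of records
    out.writeInt(payload.length);  // record length
    out.writeLong(42L);            // key prefix
    out.write(payload);            // record data
    out.close();

    DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
    int numRecords = in.readInt();
    for (int i = 0; i < numRecords; i++) {
      int length = in.readInt();
      long prefix = in.readLong();
      byte[] data = new byte[length];
      in.readFully(data);          // mirrors what UnsafeSorterSpillReader.loadNext() does
    }
    in.close();
  }
}
```
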
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 49329423dca76..0c50b4002cf7b 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -102,7 +102,7 @@ private[spark] class ExecutorAllocationManager(
"spark.dynamicAllocation.executorIdleTimeout", "60s")
private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds(
- "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s")
+ "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${Integer.MAX_VALUE}s")
// During testing, the methods to actually kill and add executors are mocked out
private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 6909015ff66e6..221b1dab43278 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -24,8 +24,8 @@ import scala.collection.mutable
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext}
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.scheduler._
+import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
@@ -45,13 +45,23 @@ private[spark] case object TaskSchedulerIsSet
private[spark] case object ExpireDeadHosts
+private case class ExecutorRegistered(executorId: String)
+
+private case class ExecutorRemoved(executorId: String)
+
private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
/**
* Lives in the driver to receive heartbeats from executors..
*/
-private[spark] class HeartbeatReceiver(sc: SparkContext)
- extends ThreadSafeRpcEndpoint with Logging {
+private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
+ extends ThreadSafeRpcEndpoint with SparkListener with Logging {
+
+ def this(sc: SparkContext) {
+ this(sc, new SystemClock)
+ }
+
+ sc.addSparkListener(this)
override val rpcEnv: RpcEnv = sc.env.rpcEnv
@@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
override def onStart(): Unit = {
timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
- Option(self).foreach(_.send(ExpireDeadHosts))
+ Option(self).foreach(_.ask[Boolean](ExpireDeadHosts))
}
}, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
}
- override def receive: PartialFunction[Any, Unit] = {
- case ExpireDeadHosts =>
- expireDeadHosts()
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+
+ // Messages sent and received locally
+ case ExecutorRegistered(executorId) =>
+ executorLastSeen(executorId) = clock.getTimeMillis()
+ context.reply(true)
+ case ExecutorRemoved(executorId) =>
+ executorLastSeen.remove(executorId)
+ context.reply(true)
case TaskSchedulerIsSet =>
scheduler = sc.taskScheduler
- }
+ context.reply(true)
+ case ExpireDeadHosts =>
+ expireDeadHosts()
+ context.reply(true)
- override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ // Messages received from executors
case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
if (scheduler != null) {
- executorLastSeen(executorId) = System.currentTimeMillis()
- eventLoopThread.submit(new Runnable {
- override def run(): Unit = Utils.tryLogNonFatalError {
- val unknownExecutor = !scheduler.executorHeartbeatReceived(
- executorId, taskMetrics, blockManagerId)
- val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
- context.reply(response)
- }
- })
+ if (executorLastSeen.contains(executorId)) {
+ executorLastSeen(executorId) = clock.getTimeMillis()
+ eventLoopThread.submit(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ val unknownExecutor = !scheduler.executorHeartbeatReceived(
+ executorId, taskMetrics, blockManagerId)
+ val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
+ context.reply(response)
+ }
+ })
+ } else {
+ // This may happen if we get an executor's in-flight heartbeat immediately
+ // after we just removed it. It's not really an error condition so we should
+ // not log warning here. Otherwise there may be a lot of noise especially if
+ // we explicitly remove executors (SPARK-4134).
+ logDebug(s"Received heartbeat from unknown executor $executorId")
+ context.reply(HeartbeatResponse(reregisterBlockManager = true))
+ }
} else {
// Because Executor will sleep several seconds before sending the first "Heartbeat", this
// case rarely happens. However, if it really happens, log it and ask the executor to
@@ -119,9 +147,30 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
}
}
+ /**
+ * If the heartbeat receiver is not stopped, notify it of executor registrations.
+ */
+ override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = {
+ Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId)))
+ }
+
+ /**
+ * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't
+ * log superfluous errors.
+ *
+ * Note that we must do this after the executor is actually removed to guard against the
+ * following race condition: if we remove an executor's metadata from our data structure
+ * prematurely, we may get an in-flight heartbeat from the executor before the executor is
+ * actually removed, in which case we will still mark the executor as a dead host later
+ * and expire it with loud error messages.
+ */
+ override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
+ Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId)))
+ }
+
private def expireDeadHosts(): Unit = {
logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.")
- val now = System.currentTimeMillis()
+ val now = clock.getTimeMillis()
for ((executorId, lastSeenMs) <- executorLastSeen) {
if (now - lastSeenMs > executorTimeoutMs) {
logWarning(s"Removing executor $executorId with no recent heartbeats: " +
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index 7fcb7830e7b0b..87ab099267b2f 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -121,6 +121,7 @@ trait Logging {
if (usingLog4j12) {
val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
if (!log4j12Initialized) {
+ // scalastyle:off println
if (Utils.isInInterpreter) {
val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties"
Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match {
@@ -141,6 +142,7 @@ trait Logging {
System.err.println(s"Spark was unable to load $defaultLogProps")
}
}
+ // scalastyle:on println
}
}
Logging.initialized = true
diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala
index 2cdc167f85af0..32df42d57dbd6 100644
--- a/core/src/main/scala/org/apache/spark/SSLOptions.scala
+++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala
@@ -17,7 +17,9 @@
package org.apache.spark
-import java.io.File
+import java.io.{File, FileInputStream}
+import java.security.{KeyStore, NoSuchAlgorithmException}
+import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory}
import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory}
import org.eclipse.jetty.util.ssl.SslContextFactory
@@ -38,7 +40,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory
* @param trustStore a path to the trust-store file
* @param trustStorePassword a password to access the trust-store file
* @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java
- * @param enabledAlgorithms a set of encryption algorithms to use
+ * @param enabledAlgorithms a set of encryption algorithms that may be used
*/
private[spark] case class SSLOptions(
enabled: Boolean = false,
@@ -48,7 +50,8 @@ private[spark] case class SSLOptions(
trustStore: Option[File] = None,
trustStorePassword: Option[String] = None,
protocol: Option[String] = None,
- enabledAlgorithms: Set[String] = Set.empty) {
+ enabledAlgorithms: Set[String] = Set.empty)
+ extends Logging {
/**
* Creates a Jetty SSL context factory according to the SSL settings represented by this object.
@@ -63,7 +66,7 @@ private[spark] case class SSLOptions(
trustStorePassword.foreach(sslContextFactory.setTrustStorePassword)
keyPassword.foreach(sslContextFactory.setKeyManagerPassword)
protocol.foreach(sslContextFactory.setProtocol)
- sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*)
+ sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*)
Some(sslContextFactory)
} else {
@@ -94,7 +97,7 @@ private[spark] case class SSLOptions(
.withValue("akka.remote.netty.tcp.security.protocol",
ConfigValueFactory.fromAnyRef(protocol.getOrElse("")))
.withValue("akka.remote.netty.tcp.security.enabled-algorithms",
- ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq))
+ ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq))
.withValue("akka.remote.netty.tcp.enable-ssl",
ConfigValueFactory.fromAnyRef(true)))
} else {
@@ -102,6 +105,36 @@ private[spark] case class SSLOptions(
}
}
+ /*
+ * The supportedAlgorithms set is a subset of the enabledAlgorithms that
+ * are supported by the current Java security provider for this protocol.
+ */
+ private val supportedAlgorithms: Set[String] = {
+ var context: SSLContext = null
+ try {
+ context = SSLContext.getInstance(protocol.orNull)
+ /* The set of supported algorithms does not depend upon the keys, trust, or
+ rng, although they will influence which algorithms are eventually used. */
+ context.init(null, null, null)
+ } catch {
+ case npe: NullPointerException =>
+ logDebug("No SSL protocol specified")
+ context = SSLContext.getDefault
+ case nsa: NoSuchAlgorithmException =>
+ logDebug(s"No support for requested SSL protocol ${protocol.get}")
+ context = SSLContext.getDefault
+ }
+
+ val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet
+
+ // Log which algorithms we are discarding
+ (enabledAlgorithms &~ providerAlgorithms).foreach { cipher =>
+ logDebug(s"Discarding unsupported cipher $cipher")
+ }
+
+ enabledAlgorithms & providerAlgorithms
+ }
+
/** Returns a string representation of this SSLOptions with all the passwords masked. */
override def toString: String = s"SSLOptions{enabled=$enabled, " +
s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " +
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index c7a7436462083..82704b1ab2189 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -315,6 +315,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_dagScheduler = ds
}
+ /**
+ * A unique identifier for the Spark application.
+ * Its format depends on the scheduler implementation.
+   * (e.g. 'local-1433865536131' for a local Spark application,
+   * 'application_1433865536131_34483' for an application running on YARN)
+ */
def applicationId: String = _applicationId
def applicationAttemptId: Option[String] = _applicationAttemptId
@@ -490,7 +498,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_schedulerBackend = sched
_taskScheduler = ts
_dagScheduler = new DAGScheduler(this)
- _heartbeatReceiver.send(TaskSchedulerIsSet)
+ _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet)
// start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's
// constructor
@@ -524,7 +532,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_executorAllocationManager =
if (dynamicAllocationEnabled) {
assert(supportDynamicAllocation,
- "Dynamic allocation of executors is currently only supported in YARN mode")
+        "Dynamic allocation of executors is currently only supported in YARN and Mesos modes")
Some(new ExecutorAllocationManager(this, listenerBus, _conf))
} else {
None
@@ -823,7 +831,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* }}}
*
* @note Small files are preferred, large file is also allowable, but may cause bad performance.
- *
+ * @note On some filesystems, `.../path/*` can be a more efficient way to read all files
+ * in a directory rather than `.../path/` or `.../path`
* @param minPartitions A suggestion value of the minimal splitting number for input data.
*/
def wholeTextFiles(
@@ -844,7 +853,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
minPartitions).setName(path)
}
-
/**
* :: Experimental ::
*
@@ -870,9 +878,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* (a-hdfs-path/part-nnnnn, its content)
* }}}
*
- * @param minPartitions A suggestion value of the minimal splitting number for input data.
- *
* @note Small files are preferred; very large files may cause bad performance.
+ * @note On some filesystems, `.../path/*` can be a more efficient way to read all files
+ * in a directory rather than `.../path/` or `.../path`
+ * @param minPartitions A suggestion value of the minimal splitting number for input data.
*/
@Experimental
def binaryFiles(
@@ -1354,10 +1363,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* Return whether dynamically adjusting the amount of resources allocated to
- * this application is supported. This is currently only available for YARN.
+ * this application is supported. This is currently only available for YARN
+ * and Mesos coarse-grained mode.
*/
- private[spark] def supportDynamicAllocation =
- master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false)
+ private[spark] def supportDynamicAllocation: Boolean = {
+ (master.contains("yarn")
+ || master.contains("mesos")
+ || _conf.getBoolean("spark.dynamicAllocation.testing", false))
+ }
/**
* :: DeveloperApi ::
@@ -1375,7 +1388,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = {
assert(supportDynamicAllocation,
- "Requesting executors is currently only supported in YARN mode")
+ "Requesting executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.requestTotalExecutors(numExecutors)
@@ -1393,7 +1406,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@DeveloperApi
override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
assert(supportDynamicAllocation,
- "Requesting executors is currently only supported in YARN mode")
+ "Requesting executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.requestExecutors(numAdditionalExecutors)
@@ -1411,7 +1424,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@DeveloperApi
override def killExecutors(executorIds: Seq[String]): Boolean = {
assert(supportDynamicAllocation,
- "Killing executors is currently only supported in YARN mode")
+ "Killing executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.killExecutors(executorIds)
@@ -1896,6 +1909,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* be a HDFS path if running on a cluster.
*/
def setCheckpointDir(directory: String) {
+
+ // If we are running on a cluster, log a warning if the directory is local.
+ // Otherwise, the driver may attempt to reconstruct the checkpointed RDD from
+ // its own local file system, which is incorrect because the checkpoint files
+ // are actually on the executor machines.
+ if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) {
+ logWarning("Checkpoint directory must be non-local " +
+ "if Spark is running on a cluster: " + directory)
+ }
+
checkpointDir = Option(directory).map { dir =>
val path = new Path(dir, UUID.randomUUID().toString)
val fs = path.getFileSystem(hadoopConfiguration)
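With the changes above, dynamic allocation is no longer restricted to YARN: coarse-grained Mesos now passes the `supportDynamicAllocation` check as well. A minimal configuration sketch, assuming a coarse-grained Mesos cluster with the external shuffle service running on each agent (the master URL and property values below are illustrative, not part of this patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical driver-side setup for dynamic allocation on Mesos coarse-grained mode.
val conf = new SparkConf()
  .setAppName("dynamic-allocation-on-mesos")
  .setMaster("mesos://zk://zk1:2181/mesos")        // placeholder Mesos master URL
  .set("spark.mesos.coarse", "true")               // dynamic allocation needs coarse-grained mode
  .set("spark.dynamicAllocation.enabled", "true")  // enables the ExecutorAllocationManager
  .set("spark.shuffle.service.enabled", "true")    // external shuffle service must run on the agents

// supportDynamicAllocation now accepts "mesos" masters, so the assert above no longer fires.
val sc = new SparkContext(conf)
```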
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index b0665570e2681..d18fc599e9890 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -22,7 +22,6 @@ import java.net.Socket
import akka.actor.ActorSystem
-import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.util.Properties
@@ -77,7 +76,7 @@ class SparkEnv (
val conf: SparkConf) extends Logging {
// TODO Remove actorSystem
- @deprecated("Actor system is no longer supported as of 1.4")
+ @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0")
val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
private[spark] var isStopped = false
@@ -90,39 +89,42 @@ class SparkEnv (
private var driverTmpDirToDelete: Option[String] = None
private[spark] def stop() {
- isStopped = true
- pythonWorkers.foreach { case(key, worker) => worker.stop() }
- Option(httpFileServer).foreach(_.stop())
- mapOutputTracker.stop()
- shuffleManager.stop()
- broadcastManager.stop()
- blockManager.stop()
- blockManager.master.stop()
- metricsSystem.stop()
- outputCommitCoordinator.stop()
- rpcEnv.shutdown()
-
- // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
- // down, but let's call it anyway in case it gets fixed in a later release
- // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
- // actorSystem.awaitTermination()
-
- // Note that blockTransferService is stopped by BlockManager since it is started by it.
-
- // If we only stop sc, but the driver process still run as a services then we need to delete
- // the tmp dir, if not, it will create too many tmp dirs.
- // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the
- // current working dir in executor which we do not need to delete.
- driverTmpDirToDelete match {
- case Some(path) => {
- try {
- Utils.deleteRecursively(new File(path))
- } catch {
- case e: Exception =>
- logWarning(s"Exception while deleting Spark temp dir: $path", e)
+
+ if (!isStopped) {
+ isStopped = true
+ pythonWorkers.values.foreach(_.stop())
+ Option(httpFileServer).foreach(_.stop())
+ mapOutputTracker.stop()
+ shuffleManager.stop()
+ broadcastManager.stop()
+ blockManager.stop()
+ blockManager.master.stop()
+ metricsSystem.stop()
+ outputCommitCoordinator.stop()
+ rpcEnv.shutdown()
+
+ // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
+ // down, but let's call it anyway in case it gets fixed in a later release
+ // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
+ // actorSystem.awaitTermination()
+
+ // Note that blockTransferService is stopped by BlockManager since it is started by it.
+
+      // If we only stop the SparkContext while the driver process keeps running as a service,
+      // we need to delete the tmp dir; otherwise it will create too many tmp dirs.
+      // We only need to delete the tmp dir created by the driver, because sparkFilesDir points
+      // to the current working dir on executors, which we do not need to delete.
+ driverTmpDirToDelete match {
+ case Some(path) => {
+ try {
+ Utils.deleteRecursively(new File(path))
+ } catch {
+ case e: Exception =>
+ logWarning(s"Exception while deleting Spark temp dir: $path", e)
+ }
}
+        case None => // We only need to delete the tmp dir created by the driver, so do nothing on executors
}
- case None => // We just need to delete tmp dir created by driver, so do nothing on executor
}
}
@@ -171,7 +173,7 @@ object SparkEnv extends Logging {
/**
* Returns the ThreadLocal SparkEnv.
*/
- @deprecated("Use SparkEnv.get instead", "1.2")
+ @deprecated("Use SparkEnv.get instead", "1.2.0")
def getThreadLocal: SparkEnv = {
env
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 1a5f2bca26c2b..b7e72d4d0ed0b 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -95,7 +95,9 @@ private[spark] class RBackend {
private[spark] object RBackend extends Logging {
def main(args: Array[String]): Unit = {
if (args.length < 1) {
+ // scalastyle:off println
System.err.println("Usage: RBackend ")
+ // scalastyle:on println
System.exit(-1)
}
val sparkRBackend = new RBackend()
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 4dfa7325934ff..23a470d6afcae 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -39,7 +39,6 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
deserializer: String,
serializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Broadcast[Object]])
extends RDD[U](parent) with Logging {
protected var dataStream: DataInputStream = _
@@ -60,7 +59,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// The stdout/stderr is shared by multiple tasks, because we use one daemon
// to launch child process as worker.
- val errThread = RRDD.createRWorker(rLibDir, listenPort)
+ val errThread = RRDD.createRWorker(listenPort)
// We use two sockets to separate input and output, then it's easy to manage
// the lifecycle of them to avoid deadlock.
@@ -161,7 +160,9 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
dataOut.write(elem.asInstanceOf[Array[Byte]])
} else if (deserializer == SerializationFormats.STRING) {
// write string(for StringRRDD)
+ // scalastyle:off println
printOut.println(elem)
+ // scalastyle:on println
}
}
@@ -233,11 +234,10 @@ private class PairwiseRRDD[T: ClassTag](
hashFunc: Array[Byte],
deserializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Object])
extends BaseRRDD[T, (Int, Array[Byte])](
parent, numPartitions, hashFunc, deserializer,
- SerializationFormats.BYTE, packageNames, rLibDir,
+ SerializationFormats.BYTE, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
override protected def readData(length: Int): (Int, Array[Byte]) = {
@@ -264,10 +264,9 @@ private class RRDD[T: ClassTag](
deserializer: String,
serializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Object])
extends BaseRRDD[T, Array[Byte]](
- parent, -1, func, deserializer, serializer, packageNames, rLibDir,
+ parent, -1, func, deserializer, serializer, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
override protected def readData(length: Int): Array[Byte] = {
@@ -291,10 +290,9 @@ private class StringRRDD[T: ClassTag](
func: Array[Byte],
deserializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Object])
extends BaseRRDD[T, String](
- parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir,
+ parent, -1, func, deserializer, SerializationFormats.STRING, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
override protected def readData(length: Int): String = {
@@ -390,9 +388,10 @@ private[r] object RRDD {
thread
}
- private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = {
- val rCommand = "Rscript"
+ private def createRProcess(port: Int, script: String): BufferedStreamThread = {
+ val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript")
val rOptions = "--vanilla"
+ val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
val rExecScript = rLibDir + "/SparkR/worker/" + script
val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript))
// Unset the R_TESTS environment variable for workers.
@@ -411,7 +410,7 @@ private[r] object RRDD {
/**
* ProcessBuilder used to launch worker R processes.
*/
- def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = {
+ def createRWorker(port: Int): BufferedStreamThread = {
val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
if (!Utils.isWindows && useDaemon) {
synchronized {
@@ -419,7 +418,7 @@ private[r] object RRDD {
// we expect one connections
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val daemonPort = serverSocket.getLocalPort
- errThread = createRProcess(rLibDir, daemonPort, "daemon.R")
+ errThread = createRProcess(daemonPort, "daemon.R")
// the socket used to send out the input of task
serverSocket.setSoTimeout(10000)
val sock = serverSocket.accept()
@@ -441,7 +440,7 @@ private[r] object RRDD {
errThread
}
} else {
- createRProcess(rLibDir, port, "worker.R")
+ createRProcess(port, "worker.R")
}
}
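Because `createRProcess` now looks up the R interpreter through `spark.sparkr.r.command` (defaulting to `Rscript`) via `SparkEnv.get.conf`, a deployment with a non-standard R installation can point SparkR workers at it through configuration alone. A small sketch, with a purely hypothetical Rscript path:

```scala
import org.apache.spark.SparkConf

// The path below is an assumption; use wherever Rscript actually lives on the worker machines.
val conf = new SparkConf()
  .setAppName("sparkr-custom-rscript")
  .set("spark.sparkr.r.command", "/opt/R/3.2.0/bin/Rscript") // read by createRProcess on each worker
```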
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
new file mode 100644
index 0000000000000..d53abd3408c55
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.File
+
+import org.apache.spark.{SparkEnv, SparkException}
+
+private[spark] object RUtils {
+ /**
+ * Get the SparkR package path in the local spark distribution.
+ */
+ def localSparkRPackagePath: Option[String] = {
+ val sparkHome = sys.env.get("SPARK_HOME")
+ sparkHome.map(
+ Seq(_, "R", "lib").mkString(File.separator)
+ )
+ }
+
+ /**
+ * Get the SparkR package path in various deployment modes.
+ * This assumes that Spark properties `spark.master` and `spark.submit.deployMode`
+ * and environment variable `SPARK_HOME` are set.
+ */
+ def sparkRPackagePath(isDriver: Boolean): String = {
+ val (master, deployMode) =
+ if (isDriver) {
+ (sys.props("spark.master"), sys.props("spark.submit.deployMode"))
+ } else {
+ val sparkConf = SparkEnv.get.conf
+ (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode"))
+ }
+
+ val isYarnCluster = master.contains("yarn") && deployMode == "cluster"
+ val isYarnClient = master.contains("yarn") && deployMode == "client"
+
+ // In YARN mode, the SparkR package is distributed as an archive symbolically
+ // linked to the "sparkr" file in the current directory. Note that this does not apply
+ // to the driver in client mode because it is run outside of the cluster.
+ if (isYarnCluster || (isYarnClient && !isDriver)) {
+ new File("sparkr").getAbsolutePath
+ } else {
+ // Otherwise, assume the package is local
+ // TODO: support this for Mesos
+ localSparkRPackagePath.getOrElse {
+ throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")
+ }
+ }
+ }
+}
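To make the resolution rules in `sparkRPackagePath` easier to follow, here is a standalone sketch (a hypothetical helper, not part of the patch) mirroring the same decision: YARN cluster mode and YARN client-mode executors use the distributed archive symlinked as `sparkr`, while everything else falls back to `$SPARK_HOME/R/lib`:

```scala
import java.io.File

// Illustration only: mirrors the branching in RUtils.sparkRPackagePath without touching SparkEnv.
def resolveSparkRPath(
    master: String,
    deployMode: String,
    isDriver: Boolean,
    sparkHome: Option[String]): String = {
  val onYarn = master.contains("yarn")
  val useShippedArchive =
    (onYarn && deployMode == "cluster") || (onYarn && deployMode == "client" && !isDriver)
  if (useShippedArchive) {
    // On YARN the archive is symlinked as "sparkr" in the container's working directory.
    new File("sparkr").getAbsolutePath
  } else {
    sparkHome
      .map(home => Seq(home, "R", "lib").mkString(File.separator))
      .getOrElse(throw new IllegalStateException("SPARK_HOME not set. Can't locate SparkR package."))
  }
}

// e.g. resolveSparkRPath("yarn-client", "client", isDriver = false, sparkHome = None)
// resolves to "<container working dir>/sparkr" on an executor.
```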
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 848b62f9de71b..f03875a3e8c89 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -18,17 +18,17 @@
package org.apache.spark.deploy
import scala.collection.mutable.HashSet
-import scala.concurrent._
+import scala.concurrent.ExecutionContext
+import scala.reflect.ClassTag
+import scala.util.{Failure, Success}
-import akka.actor._
-import akka.pattern.ask
-import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.log4j.{Level, Logger}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils}
+import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils}
/**
* Proxy that relays messages to the driver.
@@ -36,20 +36,30 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils}
* We currently don't support retry if submission fails. In HA mode, client will submit request to
* all masters and see which one could handle it.
*/
-private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
- extends Actor with ActorLogReceive with Logging {
-
- private val masterActors = driverArgs.masters.map { m =>
- context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system)))
- }
- private val lostMasters = new HashSet[Address]
- private var activeMasterActor: ActorSelection = null
-
- val timeout = RpcUtils.askTimeout(conf)
-
- override def preStart(): Unit = {
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
-
+private class ClientEndpoint(
+ override val rpcEnv: RpcEnv,
+ driverArgs: ClientArguments,
+ masterEndpoints: Seq[RpcEndpointRef],
+ conf: SparkConf)
+ extends ThreadSafeRpcEndpoint with Logging {
+
+ // A scheduled executor used to send messages at the specified time.
+ private val forwardMessageThread =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message")
+ // Used to provide the implicit parameter of `Future` methods.
+ private val forwardMessageExecutionContext =
+ ExecutionContext.fromExecutor(forwardMessageThread,
+ t => t match {
+ case ie: InterruptedException => // Exit normally
+ case e: Throwable =>
+ logError(e.getMessage, e)
+ System.exit(SparkExitCode.UNCAUGHT_EXCEPTION)
+ })
+
+ private val lostMasters = new HashSet[RpcAddress]
+ private var activeMasterEndpoint: RpcEndpointRef = null
+
+ override def onStart(): Unit = {
driverArgs.cmd match {
case "launch" =>
// TODO: We could add an env variable here and intercept it in `sc.addJar` that would
@@ -82,44 +92,52 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
driverArgs.cores,
driverArgs.supervise,
command)
-
- // This assumes only one Master is active at a time
- for (masterActor <- masterActors) {
- masterActor ! RequestSubmitDriver(driverDescription)
- }
+      asyncSendToMasterAndForwardReply[SubmitDriverResponse](
+ RequestSubmitDriver(driverDescription))
case "kill" =>
val driverId = driverArgs.driverId
- // This assumes only one Master is active at a time
- for (masterActor <- masterActors) {
- masterActor ! RequestKillDriver(driverId)
- }
+        asyncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId))
+ }
+ }
+
+ /**
+ * Send the message to master and forward the reply to self asynchronously.
+ */
+  private def asyncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = {
+ for (masterEndpoint <- masterEndpoints) {
+ masterEndpoint.ask[T](message).onComplete {
+ case Success(v) => self.send(v)
+ case Failure(e) =>
+ logWarning(s"Error sending messages to master $masterEndpoint", e)
+ }(forwardMessageExecutionContext)
}
}
/* Find out driver status then exit the JVM */
def pollAndReportStatus(driverId: String) {
- println("... waiting before polling master for driver state")
+ // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread
+ // is fine.
+ logInfo("... waiting before polling master for driver state")
Thread.sleep(5000)
- println("... polling master for driver state")
- val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout)
- .mapTo[DriverStatusResponse]
- val statusResponse = Await.result(statusFuture, timeout)
+ logInfo("... polling master for driver state")
+ val statusResponse =
+ activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId))
statusResponse.found match {
case false =>
- println(s"ERROR: Cluster master did not recognize $driverId")
+ logError(s"ERROR: Cluster master did not recognize $driverId")
System.exit(-1)
case true =>
- println(s"State of $driverId is ${statusResponse.state.get}")
+ logInfo(s"State of $driverId is ${statusResponse.state.get}")
// Worker node, if present
(statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match {
case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) =>
- println(s"Driver running on $hostPort ($id)")
+ logInfo(s"Driver running on $hostPort ($id)")
case _ =>
}
// Exception, if present
statusResponse.exception.map { e =>
- println(s"Exception from cluster was: $e")
+ logError(s"Exception from cluster was: $e")
e.printStackTrace()
System.exit(-1)
}
@@ -127,50 +145,62 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
}
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receive: PartialFunction[Any, Unit] = {
- case SubmitDriverResponse(success, driverId, message) =>
- println(message)
+ case SubmitDriverResponse(master, success, driverId, message) =>
+ logInfo(message)
if (success) {
- activeMasterActor = context.actorSelection(sender.path)
+ activeMasterEndpoint = master
pollAndReportStatus(driverId.get)
} else if (!Utils.responseFromBackup(message)) {
System.exit(-1)
}
- case KillDriverResponse(driverId, success, message) =>
- println(message)
+ case KillDriverResponse(master, driverId, success, message) =>
+ logInfo(message)
if (success) {
- activeMasterActor = context.actorSelection(sender.path)
+ activeMasterEndpoint = master
pollAndReportStatus(driverId)
} else if (!Utils.responseFromBackup(message)) {
System.exit(-1)
}
+ }
- case DisassociatedEvent(_, remoteAddress, _) =>
- if (!lostMasters.contains(remoteAddress)) {
- println(s"Error connecting to master $remoteAddress.")
- lostMasters += remoteAddress
- // Note that this heuristic does not account for the fact that a Master can recover within
- // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This
- // is not currently a concern, however, because this client does not retry submissions.
- if (lostMasters.size >= masterActors.size) {
- println("No master is available, exiting.")
- System.exit(-1)
- }
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (!lostMasters.contains(remoteAddress)) {
+ logError(s"Error connecting to master $remoteAddress.")
+ lostMasters += remoteAddress
+ // Note that this heuristic does not account for the fact that a Master can recover within
+ // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This
+ // is not currently a concern, however, because this client does not retry submissions.
+ if (lostMasters.size >= masterEndpoints.size) {
+ logError("No master is available, exiting.")
+ System.exit(-1)
}
+ }
+ }
- case AssociationErrorEvent(cause, _, remoteAddress, _, _) =>
- if (!lostMasters.contains(remoteAddress)) {
- println(s"Error connecting to master ($remoteAddress).")
- println(s"Cause was: $cause")
- lostMasters += remoteAddress
- if (lostMasters.size >= masterActors.size) {
- println("No master is available, exiting.")
- System.exit(-1)
- }
+ override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+ if (!lostMasters.contains(remoteAddress)) {
+ logError(s"Error connecting to master ($remoteAddress).")
+ logError(s"Cause was: $cause")
+ lostMasters += remoteAddress
+ if (lostMasters.size >= masterEndpoints.size) {
+ logError("No master is available, exiting.")
+ System.exit(-1)
}
+ }
+ }
+
+ override def onError(cause: Throwable): Unit = {
+ logError(s"Error processing messages, exiting.")
+ cause.printStackTrace()
+ System.exit(-1)
+ }
+
+ override def onStop(): Unit = {
+ forwardMessageThread.shutdownNow()
}
}
@@ -179,10 +209,12 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
*/
object Client {
def main(args: Array[String]) {
+ // scalastyle:off println
if (!sys.props.contains("SPARK_SUBMIT")) {
println("WARNING: This client is deprecated and will be removed in a future version of Spark")
println("Use ./bin/spark-submit with \"--master spark://host:port\"")
}
+ // scalastyle:on println
val conf = new SparkConf()
val driverArgs = new ClientArguments(args)
@@ -194,15 +226,13 @@ object Client {
conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING"))
Logger.getRootLogger.setLevel(driverArgs.logLevel)
- val (actorSystem, _) = AkkaUtils.createActorSystem(
- "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))
+ val rpcEnv =
+ RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))
- // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely
- for (m <- driverArgs.masters) {
- Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem))
- }
- actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
+ val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL).
+ map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME))
+ rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf))
- actorSystem.awaitTermination()
+ rpcEnv.awaitTermination()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 316e2d59f01b8..72cc330a398da 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -72,9 +72,11 @@ private[deploy] class ClientArguments(args: Array[String]) {
cmd = "launch"
if (!ClientArguments.isValidJarUrl(_jarUrl)) {
+ // scalastyle:off println
println(s"Jar url '${_jarUrl}' is not in valid format.")
println(s"Must be a jar file path in URL format " +
"(e.g. hdfs://host:port/XX.jar, file:///XX.jar)")
+ // scalastyle:on println
printUsageAndExit(-1)
}
@@ -110,14 +112,16 @@ private[deploy] class ClientArguments(args: Array[String]) {
| (default: $DEFAULT_SUPERVISE)
| -v, --verbose Print more debugging output
""".stripMargin
+ // scalastyle:off println
System.err.println(usage)
+ // scalastyle:on println
System.exit(exitCode)
}
}
private[deploy] object ClientArguments {
val DEFAULT_CORES = 1
- val DEFAULT_MEMORY = 512 // MB
+ val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // MB
val DEFAULT_SUPERVISE = false
def isValidJarUrl(s: String): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 9db6fd1ac4dbe..12727de9b4cf3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -24,11 +24,12 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.RecoveryState.MasterState
import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
private[deploy] sealed trait DeployMessage extends Serializable
-/** Contains messages sent between Scheduler actor nodes. */
+/** Contains messages sent between Scheduler endpoint nodes. */
private[deploy] object DeployMessages {
// Worker to Master
@@ -37,6 +38,7 @@ private[deploy] object DeployMessages {
id: String,
host: String,
port: Int,
+ worker: RpcEndpointRef,
cores: Int,
memory: Int,
webUiPort: Int,
@@ -63,11 +65,11 @@ private[deploy] object DeployMessages {
case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription],
driverIds: Seq[String])
- case class Heartbeat(workerId: String) extends DeployMessage
+ case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage
// Master to Worker
- case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage
+ case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage
@@ -92,13 +94,13 @@ private[deploy] object DeployMessages {
// Worker internal
- case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders
+ case object WorkDirCleanup // Sent to Worker endpoint periodically for cleaning up app folders
case object ReregisterWithMaster // used when a worker attempts to reconnect to a master
// AppClient to Master
- case class RegisterApplication(appDescription: ApplicationDescription)
+ case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef)
extends DeployMessage
case class UnregisterApplication(appId: String)
@@ -107,7 +109,7 @@ private[deploy] object DeployMessages {
// Master to AppClient
- case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage
+ case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage
// TODO(matei): replace hostPort with host
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
@@ -123,12 +125,14 @@ private[deploy] object DeployMessages {
case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage
- case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String)
+ case class SubmitDriverResponse(
+ master: RpcEndpointRef, success: Boolean, driverId: Option[String], message: String)
extends DeployMessage
case class RequestKillDriver(driverId: String) extends DeployMessage
- case class KillDriverResponse(driverId: String, success: Boolean, message: String)
+ case class KillDriverResponse(
+ master: RpcEndpointRef, driverId: String, success: Boolean, message: String)
extends DeployMessage
case class RequestDriverStatus(driverId: String) extends DeployMessage
@@ -142,7 +146,7 @@ private[deploy] object DeployMessages {
// Master to Worker & AppClient
- case class MasterChanged(masterUrl: String, masterWebUiUrl: String)
+ case class MasterChanged(master: RpcEndpointRef, masterWebUiUrl: String)
// MasterWebUI To Master
diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
index 2954f932b4f41..ccffb36652988 100644
--- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
@@ -76,12 +76,13 @@ private[deploy] object JsonProtocol {
}
def writeMasterState(obj: MasterStateResponse): JObject = {
+ val aliveWorkers = obj.workers.filter(_.isAlive())
("url" -> obj.uri) ~
("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~
- ("cores" -> obj.workers.map(_.cores).sum) ~
- ("coresused" -> obj.workers.map(_.coresUsed).sum) ~
- ("memory" -> obj.workers.map(_.memory).sum) ~
- ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~
+ ("cores" -> aliveWorkers.map(_.cores).sum) ~
+ ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~
+ ("memory" -> aliveWorkers.map(_.memory).sum) ~
+ ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~
("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~
("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~
("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 0550f00a172ab..53356addf6edb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -19,8 +19,7 @@ package org.apache.spark.deploy
import scala.collection.mutable.ArrayBuffer
-import akka.actor.ActorSystem
-
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.deploy.master.Master
@@ -41,8 +40,8 @@ class LocalSparkCluster(
extends Logging {
private val localHostname = Utils.localHostName()
- private val masterActorSystems = ArrayBuffer[ActorSystem]()
- private val workerActorSystems = ArrayBuffer[ActorSystem]()
+ private val masterRpcEnvs = ArrayBuffer[RpcEnv]()
+ private val workerRpcEnvs = ArrayBuffer[RpcEnv]()
// exposed for testing
var masterWebUIPort = -1
@@ -55,18 +54,17 @@ class LocalSparkCluster(
.set("spark.shuffle.service.enabled", "false")
/* Start the Master */
- val (masterSystem, masterPort, webUiPort, _) =
- Master.startSystemAndActor(localHostname, 0, 0, _conf)
+ val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf)
masterWebUIPort = webUiPort
- masterActorSystems += masterSystem
- val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort
+ masterRpcEnvs += rpcEnv
+ val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port
val masters = Array(masterUrl)
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
- val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
+ val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker,
memoryPerWorker, masters, null, Some(workerNum), _conf)
- workerActorSystems += workerSystem
+ workerRpcEnvs += workerEnv
}
masters
@@ -77,11 +75,11 @@ class LocalSparkCluster(
// Stop the workers before the master so they don't get upset that it disconnected
// TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors!
// This is unfortunate, but for now we just comment it out.
- workerActorSystems.foreach(_.shutdown())
+ workerRpcEnvs.foreach(_.shutdown())
// workerActorSystems.foreach(_.awaitTermination())
- masterActorSystems.foreach(_.shutdown())
+ masterRpcEnvs.foreach(_.shutdown())
// masterActorSystems.foreach(_.awaitTermination())
- masterActorSystems.clear()
- workerActorSystems.clear()
+ masterRpcEnvs.clear()
+ workerRpcEnvs.clear()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index e99779f299785..c0cab22fa8252 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.fs.Path
-import org.apache.spark.api.r.RBackend
+import org.apache.spark.api.r.{RBackend, RUtils}
import org.apache.spark.util.RedirectThread
/**
@@ -71,9 +71,10 @@ object RRunner {
val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs)
val env = builder.environment()
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
- val sparkHome = System.getenv("SPARK_HOME")
+ val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
+ env.put("SPARKR_PACKAGE_DIR", rPackageDir)
env.put("R_PROFILE_USER",
- Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator))
+ Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator))
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()
@@ -85,7 +86,9 @@ object RRunner {
}
System.exit(returnCode)
} else {
+ // scalastyle:off println
System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds")
+ // scalastyle:on println
System.exit(-1)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 7fa75ac8c2b54..6d14590a1d192 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -334,6 +334,19 @@ class SparkHadoopUtil extends Logging {
* Stop the thread that does the delegation token updates.
*/
private[spark] def stopExecutorDelegationTokenRenewer() {}
+
+ /**
+ * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism.
+ * This is to prevent the DFSClient from using an old cached token to connect to the NameNode.
+ */
+ private[spark] def getConfBypassingFSCache(
+ hadoopConf: Configuration,
+ scheme: String): Configuration = {
+ val newConf = new Configuration(hadoopConf)
+ val confKey = s"fs.${scheme}.impl.disable.cache"
+ newConf.setBoolean(confKey, true)
+ newConf
+ }
}
object SparkHadoopUtil {
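`getConfBypassingFSCache` only flips `fs.<scheme>.impl.disable.cache` on a copy of the configuration. Since the helper is `private[spark]`, code outside Spark would spell out the same effect directly; a minimal sketch for the `hdfs` scheme (paths are placeholders):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

// Copy the base configuration and disable the FileSystem cache for "hdfs", so the next
// FileSystem.get builds a fresh client instead of reusing one that holds a stale delegation token.
val baseConf = new Configuration()
val freshConf = new Configuration(baseConf)
freshConf.setBoolean("fs.hdfs.impl.disable.cache", true)

val path = new Path("hdfs:///tmp")
val fs = FileSystem.get(path.toUri, freshConf) // not served from the process-wide cache
```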
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index abf222757a95b..7089a7e26707f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -37,6 +37,7 @@ import org.apache.ivy.core.settings.IvySettings
import org.apache.ivy.plugins.matcher.GlobPatternMatcher
import org.apache.ivy.plugins.repository.file.FileRepository
import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver}
+import org.apache.spark.api.r.RUtils
import org.apache.spark.SPARK_VERSION
import org.apache.spark.deploy.rest._
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
@@ -79,9 +80,11 @@ object SparkSubmit {
private val SPARK_SHELL = "spark-shell"
private val PYSPARK_SHELL = "pyspark-shell"
private val SPARKR_SHELL = "sparkr-shell"
+ private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip"
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
+ // scalastyle:off println
// Exposed for testing
private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode)
private[spark] var printStream: PrintStream = System.err
@@ -102,11 +105,14 @@ object SparkSubmit {
printStream.println("Type --help for more information.")
exitFn(0)
}
+ // scalastyle:on println
def main(args: Array[String]): Unit = {
val appArgs = new SparkSubmitArguments(args)
if (appArgs.verbose) {
+ // scalastyle:off println
printStream.println(appArgs)
+ // scalastyle:on println
}
appArgs.action match {
case SparkSubmitAction.SUBMIT => submit(appArgs)
@@ -160,7 +166,9 @@ object SparkSubmit {
// makes the message printed to the output by the JVM not very helpful. Instead,
// detect exceptions with empty stack traces here, and treat them differently.
if (e.getStackTrace().length == 0) {
+ // scalastyle:off println
printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
+ // scalastyle:on println
exitFn(1)
} else {
throw e
@@ -178,7 +186,9 @@ object SparkSubmit {
// to use the legacy gateway if the master endpoint turns out to be not a REST server.
if (args.isStandaloneCluster && args.useRest) {
try {
+ // scalastyle:off println
printStream.println("Running Spark using the REST application submission protocol.")
+ // scalastyle:on println
doRunMain()
} catch {
// Fail over to use the legacy submission gateway
@@ -254,6 +264,12 @@ object SparkSubmit {
}
}
+ // Update args.deployMode if it is null. It will be passed down as a Spark property later.
+ (args.deployMode, deployMode) match {
+ case (null, CLIENT) => args.deployMode = "client"
+ case (null, CLUSTER) => args.deployMode = "cluster"
+ case _ =>
+ }
val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER
@@ -339,6 +355,23 @@ object SparkSubmit {
}
}
+    // In YARN mode for an R app, add the SparkR package archive to the archives
+    // that are distributed with the job
+ if (args.isR && clusterManager == YARN) {
+ val rPackagePath = RUtils.localSparkRPackagePath
+ if (rPackagePath.isEmpty) {
+ printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.")
+ }
+ val rPackageFile = new File(rPackagePath.get, SPARKR_PACKAGE_ARCHIVE)
+ if (!rPackageFile.exists()) {
+ printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.")
+ }
+ val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath)
+
+      // Assign the symbolic link name "sparkr" to the shipped package.
+ args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr")
+ }
+
// If we're running a R app, set the main class to our specific R runner
if (args.isR && deployMode == CLIENT) {
if (args.primaryResource == SPARKR_SHELL) {
@@ -367,6 +400,8 @@ object SparkSubmit {
// All cluster managers
OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"),
+ OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
+ sysProp = "spark.submit.deployMode"),
OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"),
OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"),
OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"),
@@ -558,6 +593,7 @@ object SparkSubmit {
sysProps: Map[String, String],
childMainClass: String,
verbose: Boolean): Unit = {
+ // scalastyle:off println
if (verbose) {
printStream.println(s"Main class:\n$childMainClass")
printStream.println(s"Arguments:\n${childArgs.mkString("\n")}")
@@ -565,6 +601,7 @@ object SparkSubmit {
printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}")
printStream.println("\n")
}
+ // scalastyle:on println
val loader =
if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) {
@@ -592,8 +629,10 @@ object SparkSubmit {
case e: ClassNotFoundException =>
e.printStackTrace(printStream)
if (childMainClass.contains("thriftserver")) {
+ // scalastyle:off println
printStream.println(s"Failed to load main class $childMainClass.")
printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.")
+ // scalastyle:on println
}
System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
}
@@ -756,6 +795,22 @@ private[spark] object SparkSubmitUtils {
val cr = new ChainResolver
cr.setName("list")
+ val repositoryList = remoteRepos.getOrElse("")
+    // add any remote repositories other than Maven Central
+ if (repositoryList.trim.nonEmpty) {
+ repositoryList.split(",").zipWithIndex.foreach { case (repo, i) =>
+ val brr: IBiblioResolver = new IBiblioResolver
+ brr.setM2compatible(true)
+ brr.setUsepoms(true)
+ brr.setRoot(repo)
+ brr.setName(s"repo-${i + 1}")
+ cr.add(brr)
+ // scalastyle:off println
+ printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}")
+ // scalastyle:on println
+ }
+ }
+
val localM2 = new IBiblioResolver
localM2.setM2compatible(true)
localM2.setRoot(m2Path.toURI.toString)
@@ -786,20 +841,6 @@ private[spark] object SparkSubmitUtils {
sp.setRoot("http://dl.bintray.com/spark-packages/maven")
sp.setName("spark-packages")
cr.add(sp)
-
- val repositoryList = remoteRepos.getOrElse("")
- // add any other remote repositories other than maven central
- if (repositoryList.trim.nonEmpty) {
- repositoryList.split(",").zipWithIndex.foreach { case (repo, i) =>
- val brr: IBiblioResolver = new IBiblioResolver
- brr.setM2compatible(true)
- brr.setUsepoms(true)
- brr.setRoot(repo)
- brr.setName(s"repo-${i + 1}")
- cr.add(brr)
- printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}")
- }
- }
cr
}
@@ -829,7 +870,9 @@ private[spark] object SparkSubmitUtils {
val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version)
val dd = new DefaultDependencyDescriptor(ri, false, false)
dd.addDependencyConfiguration(ivyConfName, ivyConfName)
+ // scalastyle:off println
printStream.println(s"${dd.getDependencyId} added as a dependency")
+ // scalastyle:on println
md.addDependency(dd)
}
}
@@ -896,9 +939,11 @@ private[spark] object SparkSubmitUtils {
ivySettings.setDefaultCache(new File(alternateIvyCache, "cache"))
new File(alternateIvyCache, "jars")
}
+ // scalastyle:off println
printStream.println(
s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}")
printStream.println(s"The jars for the packages stored in: $packagesDirectory")
+ // scalastyle:on println
// create a pattern matcher
ivySettings.addMatcher(new GlobPatternMatcher)
// create the dependency resolvers
@@ -922,6 +967,15 @@ private[spark] object SparkSubmitUtils {
// A Module descriptor must be specified. Entries are dummy strings
val md = getModuleDescriptor
+ // clear ivy resolution from previous launches. The resolution file is usually at
+ // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file
+ // leads to confusion with Ivy when the files can no longer be found at the repository
+    // declared in that file.
+ val mdId = md.getModuleRevisionId
+ val previousResolution = new File(ivySettings.getDefaultCache,
+ s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml")
+ if (previousResolution.exists) previousResolution.delete
+
md.setDefaultConf(ivyConfName)
// Add exclusion rules for Spark and Scala Library
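Because SparkSubmit now always propagates the resolved deploy mode as `spark.submit.deployMode`, downstream code (such as `RUtils.sparkRPackagePath` above) can count on the property being set. A small defensive-read sketch; the `"client"` fallback is an assumption for callers that bypass spark-submit, not something this patch mandates:

```scala
import org.apache.spark.SparkConf

// Read the deploy mode that SparkSubmit recorded; fall back to "client" only when the
// application was launched without spark-submit.
val conf = new SparkConf()
val deployMode = conf.get("spark.submit.deployMode", "client")
```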
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index b7429a901e162..ebb39c354dff1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -79,6 +79,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
/** Default properties present in the currently defined defaults file. */
lazy val defaultSparkProperties: HashMap[String, String] = {
val defaultProperties = new HashMap[String, String]()
+ // scalastyle:off println
if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile")
Option(propertiesFile).foreach { filename =>
Utils.getPropertiesFromFile(filename).foreach { case (k, v) =>
@@ -86,6 +87,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v")
}
}
+ // scalastyle:on println
defaultProperties
}
@@ -162,6 +164,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
.orNull
executorCores = Option(executorCores)
.orElse(sparkProperties.get("spark.executor.cores"))
+ .orElse(env.get("SPARK_EXECUTOR_CORES"))
.orNull
totalExecutorCores = Option(totalExecutorCores)
.orElse(sparkProperties.get("spark.cores.max"))
@@ -451,6 +454,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
}
private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = {
+ // scalastyle:off println
val outStream = SparkSubmit.printStream
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
@@ -461,8 +465,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
|Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin)
outStream.println(command)
+ val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB
outStream.println(
- """
+ s"""
|Options:
| --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local.
| --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or
@@ -488,7 +493,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| --properties-file FILE Path to a file from which to load extra properties. If not
| specified, this will look for conf/spark-defaults.conf.
|
- | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M).
+ | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: ${mem_mb}M).
| --driver-java-options Extra Java options to pass to the driver.
| --driver-library-path Extra library path entries to pass to the driver.
| --driver-class-path Extra class path entries to pass to the driver. Note that
@@ -539,6 +544,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
outStream.println("CLI options:")
outStream.println(getSqlShellOptions())
}
+ // scalastyle:on println
SparkSubmit.exitFn(exitCode)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 43c8a934c311a..79b251e7e62fe 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -17,20 +17,17 @@
package org.apache.spark.deploy.client
-import java.util.concurrent.TimeoutException
+import java.util.concurrent._
+import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}
-import scala.concurrent.Await
-import scala.concurrent.duration._
-
-import akka.actor._
-import akka.pattern.ask
-import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
+import scala.util.control.NonFatal
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
-import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils}
+import org.apache.spark.rpc._
+import org.apache.spark.util.{ThreadUtils, Utils}
/**
* Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL,
@@ -40,98 +37,143 @@ import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils}
* @param masterUrls Each url should look like spark://host:port.
*/
private[spark] class AppClient(
- actorSystem: ActorSystem,
+ rpcEnv: RpcEnv,
masterUrls: Array[String],
appDescription: ApplicationDescription,
listener: AppClientListener,
conf: SparkConf)
extends Logging {
- private val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))
+ private val masterRpcAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
- private val REGISTRATION_TIMEOUT = 20.seconds
+ private val REGISTRATION_TIMEOUT_SECONDS = 20
private val REGISTRATION_RETRIES = 3
- private var masterAddress: Address = null
- private var actor: ActorRef = null
+ private var endpoint: RpcEndpointRef = null
private var appId: String = null
- private var registered = false
- private var activeMasterUrl: String = null
+ @volatile private var registered = false
+
+ private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint
+ with Logging {
+
+ private var master: Option[RpcEndpointRef] = None
+ // To avoid calling listener.disconnected() multiple times
+ private var alreadyDisconnected = false
+ @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times
+ @volatile private var registerMasterFutures: Array[JFuture[_]] = null
+ @volatile private var registrationRetryTimer: JScheduledFuture[_] = null
+
+ // A thread pool for registering with masters. Because registering with a master is a blocking
+ // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same
+ // time so that we can register with all masters.
+ private val registerMasterThreadPool = new ThreadPoolExecutor(
+ 0,
+ masterRpcAddresses.size, // Make sure we can register with all masters at the same time
+ 60L, TimeUnit.SECONDS,
+ new SynchronousQueue[Runnable](),
+ ThreadUtils.namedThreadFactory("appclient-register-master-threadpool"))
- private class ClientActor extends Actor with ActorLogReceive with Logging {
- var master: ActorSelection = null
- var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
- var alreadyDead = false // To avoid calling listener.dead() multiple times
- var registrationRetryTimer: Option[Cancellable] = None
+ // A scheduled executor for scheduling the registration actions
+ private val registrationRetryThread =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread")
- override def preStart() {
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ override def onStart(): Unit = {
try {
- registerWithMaster()
+ registerWithMaster(1)
} catch {
case e: Exception =>
logWarning("Failed to connect to master", e)
markDisconnected()
- context.stop(self)
+ stop()
}
}
- def tryRegisterAllMasters() {
- for (masterAkkaUrl <- masterAkkaUrls) {
- logInfo("Connecting to master " + masterAkkaUrl + "...")
- val actor = context.actorSelection(masterAkkaUrl)
- actor ! RegisterApplication(appDescription)
+ /**
+     * Register with all masters asynchronously and return an array of `Future`s for cancellation.
+ */
+ private def tryRegisterAllMasters(): Array[JFuture[_]] = {
+ for (masterAddress <- masterRpcAddresses) yield {
+ registerMasterThreadPool.submit(new Runnable {
+ override def run(): Unit = try {
+ if (registered) {
+ return
+ }
+ logInfo("Connecting to master " + masterAddress.toSparkURL + "...")
+ val masterRef =
+ rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
+ masterRef.send(RegisterApplication(appDescription, self))
+ } catch {
+ case ie: InterruptedException => // Cancelled
+ case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
+ }
+ })
}
}
- def registerWithMaster() {
- tryRegisterAllMasters()
- import context.dispatcher
- var retries = 0
- registrationRetryTimer = Some {
- context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
+ /**
+ * Register with all masters asynchronously. It will call `registerWithMaster` every
+     * Register with all masters asynchronously. It will call `registerWithMaster` every
+     * REGISTRATION_TIMEOUT_SECONDS seconds until it has retried REGISTRATION_RETRIES times.
+     * Once we connect to a master successfully, all scheduled work and Futures will be cancelled.
+     *
+     * nthRetry means this is the nth attempt to register with a master.
+ private def registerWithMaster(nthRetry: Int) {
+ registerMasterFutures = tryRegisterAllMasters()
+ registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = {
Utils.tryOrExit {
- retries += 1
if (registered) {
- registrationRetryTimer.foreach(_.cancel())
- } else if (retries >= REGISTRATION_RETRIES) {
+ registerMasterFutures.foreach(_.cancel(true))
+ registerMasterThreadPool.shutdownNow()
+ } else if (nthRetry >= REGISTRATION_RETRIES) {
markDead("All masters are unresponsive! Giving up.")
} else {
- tryRegisterAllMasters()
+ registerMasterFutures.foreach(_.cancel(true))
+ registerWithMaster(nthRetry + 1)
}
}
}
- }
+ }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
}
- def changeMaster(url: String) {
- // activeMasterUrl is a valid Spark url since we receive it from master.
- activeMasterUrl = url
- master = context.actorSelection(
- Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem)))
- masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem))
+ /**
+ * Send a message to the current master. If we have not yet registered successfully with any
+ * master, the message will be dropped.
+ */
+ private def sendToMaster(message: Any): Unit = {
+ master match {
+ case Some(masterRef) => masterRef.send(message)
+      case None => logWarning(s"Dropping $message because we have not yet connected to a master")
+ }
}
- private def isPossibleMaster(remoteUrl: Address) = {
- masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort)
+ private def isPossibleMaster(remoteAddress: RpcAddress): Boolean = {
+ masterRpcAddresses.contains(remoteAddress)
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
- case RegisteredApplication(appId_, masterUrl) =>
+ override def receive: PartialFunction[Any, Unit] = {
+ case RegisteredApplication(appId_, masterRef) =>
+ // FIXME How to handle the following cases?
+ // 1. A master receives multiple registrations and sends back multiple
+ // RegisteredApplications due to an unstable network.
+ // 2. Receive multiple RegisteredApplication from different masters because the master is
+ // changing.
appId = appId_
registered = true
- changeMaster(masterUrl)
+ master = Some(masterRef)
listener.connected(appId)
case ApplicationRemoved(message) =>
markDead("Master removed our application: %s".format(message))
- context.stop(self)
+ stop()
case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
val fullId = appId + "/" + id
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort,
cores))
- master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)
+ // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not
+ // guaranteed), `ExecutorStateChanged` may be sent to a dead master.
+ sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None))
listener.executorAdded(fullId, workerId, hostPort, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
@@ -142,24 +184,32 @@ private[spark] class AppClient(
listener.executorRemoved(fullId, message.getOrElse(""), exitStatus)
}
- case MasterChanged(masterUrl, masterWebUiUrl) =>
- logInfo("Master has changed, new master is at " + masterUrl)
- changeMaster(masterUrl)
+ case MasterChanged(masterRef, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL)
+ master = Some(masterRef)
alreadyDisconnected = false
- sender ! MasterChangeAcknowledged(appId)
+ masterRef.send(MasterChangeAcknowledged(appId))
+ }
- case DisassociatedEvent(_, address, _) if address == masterAddress =>
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case StopAppClient =>
+ markDead("Application has been stopped.")
+ sendToMaster(UnregisterApplication(appId))
+ context.reply(true)
+ stop()
+ }
+
+ override def onDisconnected(address: RpcAddress): Unit = {
+ if (master.exists(_.address == address)) {
logWarning(s"Connection to $address failed; waiting for master to reconnect...")
markDisconnected()
+ }
+ }
- case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) =>
+ override def onNetworkError(cause: Throwable, address: RpcAddress): Unit = {
+ if (isPossibleMaster(address)) {
logWarning(s"Could not connect to $address: $cause")
-
- case StopAppClient =>
- markDead("Application has been stopped.")
- master ! UnregisterApplication(appId)
- sender ! true
- context.stop(self)
+ }
}
/**
@@ -179,28 +229,31 @@ private[spark] class AppClient(
}
}
- override def postStop() {
- registrationRetryTimer.foreach(_.cancel())
+ override def onStop(): Unit = {
+ if (registrationRetryTimer != null) {
+ registrationRetryTimer.cancel(true)
+ }
+ registrationRetryThread.shutdownNow()
+ registerMasterFutures.foreach(_.cancel(true))
+ registerMasterThreadPool.shutdownNow()
}
}
def start() {
// Just launch an actor; it will call back into the listener.
- actor = actorSystem.actorOf(Props(new ClientActor))
+ endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv))
}
def stop() {
- if (actor != null) {
+ if (endpoint != null) {
try {
- val timeout = RpcUtils.askTimeout(conf)
- val future = actor.ask(StopAppClient)(timeout)
- Await.result(future, timeout)
+ endpoint.askWithRetry[Boolean](StopAppClient)
} catch {
case e: TimeoutException =>
logInfo("Stop request to Master timed out; it may already be shut down.")
}
- actor = null
+ endpoint = null
}
}
}
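The rewritten AppClient drives registration retries with a JDK scheduled executor instead of Akka's scheduler. A self-contained sketch of that retry shape using plain `java.util.concurrent` (names, timing values, and the println are illustrative; the real code uses Spark's ThreadUtils helpers and RPC endpoints):

```scala
import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

object RegistrationRetrySketch {
  private val registered = new AtomicBoolean(false)
  private val scheduler = Executors.newSingleThreadScheduledExecutor()
  private val maxRetries = 3
  private val retryIntervalSeconds = 20L

  def registerWithMaster(nthRetry: Int): Unit = {
    // Kick off the (possibly blocking) registration attempts here, e.g. on a separate pool ...
    scheduler.schedule(new Runnable {
      override def run(): Unit = {
        if (registered.get()) {
          scheduler.shutdownNow()          // registered: stop scheduling further retries
        } else if (nthRetry >= maxRetries) {
          println("All masters are unresponsive! Giving up.")
          scheduler.shutdownNow()
        } else {
          registerWithMaster(nthRetry + 1) // give up on this round and try again
        }
      }
    }, retryIntervalSeconds, TimeUnit.SECONDS)
  }

  def main(args: Array[String]): Unit = registerWithMaster(1)
}
```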
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index 40835b9550586..1c79089303e3d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -17,9 +17,10 @@
package org.apache.spark.deploy.client
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, Command}
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.Utils
private[spark] object TestClient {
@@ -46,13 +47,12 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
val conf = new SparkConf
- val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0,
- conf = conf, securityManager = new SecurityManager(conf))
+ val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf))
val desc = new ApplicationDescription("TestClient", Some(1), 512,
Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored")
val listener = new TestListener
- val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf)
+ val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf)
client.start()
- actorSystem.awaitTermination()
+ rpcEnv.awaitTermination()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala
index c5ac45c6730d3..a98b1fa8f83a1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala
@@ -19,7 +19,9 @@ package org.apache.spark.deploy.client
private[spark] object TestExecutor {
def main(args: Array[String]) {
+ // scalastyle:off println
println("Hello world!")
+ // scalastyle:on println
while (true) {
Thread.sleep(1000)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 5427a88f32ffd..2cc465e55fceb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -83,12 +83,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
// List of application logs to be deleted by event log cleaner.
private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo]
- // Constants used to parse Spark 1.0.0 log directories.
- private[history] val LOG_PREFIX = "EVENT_LOG_"
- private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_"
- private[history] val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_"
- private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
-
/**
* Return a runnable that performs the given operation on the event logs.
* This operation is expected to be executed periodically.
@@ -146,7 +140,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
override def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] = {
try {
applications.get(appId).flatMap { appInfo =>
- appInfo.attempts.find(_.attemptId == attemptId).map { attempt =>
+ appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt =>
val replayBus = new ReplayListenerBus()
val ui = {
val conf = this.conf.clone()
@@ -155,20 +149,20 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime)
// Do not call ui.bind() to avoid creating a new server for each application
}
-
val appListener = new ApplicationEventListener()
replayBus.addListener(appListener)
val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus)
-
- ui.setAppName(s"${appInfo.name} ($appId)")
-
- val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false)
- ui.getSecurityManager.setAcls(uiAclsEnabled)
- // make sure to set admin acls before view acls so they are properly picked up
- ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse(""))
- ui.getSecurityManager.setViewAcls(attempt.sparkUser,
- appListener.viewAcls.getOrElse(""))
- ui
+ appInfo.map { info =>
+ ui.setAppName(s"${info.name} ($appId)")
+
+ val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false)
+ ui.getSecurityManager.setAcls(uiAclsEnabled)
+ // make sure to set admin acls before view acls so they are properly picked up
+ ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse(""))
+ ui.getSecurityManager.setViewAcls(attempt.sparkUser,
+ appListener.viewAcls.getOrElse(""))
+ ui
+ }
}
}
} catch {
@@ -282,8 +276,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
val newAttempts = logs.flatMap { fileStatus =>
try {
val res = replay(fileStatus, bus)
- logInfo(s"Application log ${res.logPath} loaded successfully.")
- Some(res)
+ res match {
+ case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.")
+ case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " +
+ "The application may have not started.")
+ }
+ res
} catch {
case e: Exception =>
logError(
@@ -429,9 +427,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
/**
* Replays the events in the specified log file and returns information about the associated
- * application.
+ * application. Return `None` if the application ID cannot be located.
*/
- private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationAttemptInfo = {
+ private def replay(
+ eventLog: FileStatus,
+ bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = {
val logPath = eventLog.getPath()
logInfo(s"Replaying log path: $logPath")
val logInput =
@@ -445,16 +445,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
val appCompleted = isApplicationCompleted(eventLog)
bus.addListener(appListener)
bus.replay(logInput, logPath.toString, !appCompleted)
- new FsApplicationAttemptInfo(
- logPath.getName(),
- appListener.appName.getOrElse(NOT_STARTED),
- appListener.appId.getOrElse(logPath.getName()),
- appListener.appAttemptId,
- appListener.startTime.getOrElse(-1L),
- appListener.endTime.getOrElse(-1L),
- getModificationTime(eventLog).get,
- appListener.sparkUser.getOrElse(NOT_STARTED),
- appCompleted)
+
+ // Without an app ID, new logs will render incorrectly in the listing page, so do not list or
+ // try to show their UI. Some old versions of Spark generate logs without an app ID, so let
+ // logs generated by those versions go through.
+ if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) {
+ Some(new FsApplicationAttemptInfo(
+ logPath.getName(),
+ appListener.appName.getOrElse(NOT_STARTED),
+ appListener.appId.getOrElse(logPath.getName()),
+ appListener.appAttemptId,
+ appListener.startTime.getOrElse(-1L),
+ appListener.endTime.getOrElse(-1L),
+ getModificationTime(eventLog).get,
+ appListener.sparkUser.getOrElse(NOT_STARTED),
+ appCompleted))
+ } else {
+ None
+ }
} finally {
logInput.close()
}
@@ -529,10 +537,34 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
}
}
+ /**
+ * Returns whether the version of Spark that generated logs records app IDs. App IDs were added
+ * in Spark 1.1.
+ */
+ private def sparkVersionHasAppId(entry: FileStatus): Boolean = {
+ if (isLegacyLogDirectory(entry)) {
+ fs.listStatus(entry.getPath())
+ .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) }
+ .map { status =>
+ val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length())
+ version != "1.0" && version != "1.1"
+ }
+ .getOrElse(true)
+ } else {
+ true
+ }
+ }
+
}
-private object FsHistoryProvider {
+private[history] object FsHistoryProvider {
val DEFAULT_LOG_DIR = "file:/tmp/spark-events"
+
+ // Constants used to parse Spark 1.0.0 log directories.
+ val LOG_PREFIX = "EVENT_LOG_"
+ val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_"
+ val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_"
+ val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
}
private class FsApplicationAttemptInfo(
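
Since `replay(...)` above now returns an Option, its callers keep only the attempts whose application ID could be recovered, as the `logs.flatMap` hunk shows. Below is a small, self-contained sketch of that pattern; the `AttemptInfo` case class and `loadAttempt` helper are hypothetical stand-ins for FsApplicationAttemptInfo and the new replay, and only the Option/flatMap shape mirrors the patch.

    // Hypothetical stand-ins for FsApplicationAttemptInfo and the new replay().
    case class AttemptInfo(logPath: String, appId: String)

    object ReplaySketch {
      // Pretend loader: logs without a recoverable app ID yield None, like the new replay().
      def loadAttempt(path: String): Option[AttemptInfo] =
        if (path.contains("no-appid")) None else Some(AttemptInfo(path, s"app-$path"))

      def main(args: Array[String]): Unit = {
        val logs = Seq("eventlog-1", "eventlog-2-no-appid", "eventlog-3")
        val attempts = logs.flatMap { path =>
          val res = loadAttempt(path)
          // The patch logs a warning here instead of listing the application.
          if (res.isEmpty) println(s"Failed to load application log $path; it may not have started.")
          res
        }
        println(attempts.map(_.appId))  // List(app-eventlog-1, app-eventlog-3)
      }
    }
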
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index 4692d22651c93..18265df9faa2c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -56,6 +56,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin
Utils.loadDefaultSparkProperties(conf, propertiesFile)
private def printUsageAndExit(exitCode: Int) {
+ // scalastyle:off println
System.err.println(
"""
|Usage: HistoryServer [options]
@@ -84,6 +85,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin
| spark.history.fs.updateInterval How often to reload log data from storage
| (in seconds, default: 10)
|""".stripMargin)
+ // scalastyle:on println
System.exit(exitCode)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index 1620e95bea218..aa54ed9360f36 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -22,10 +22,9 @@ import java.util.Date
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import akka.actor.ActorRef
-
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.ApplicationDescription
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
private[spark] class ApplicationInfo(
@@ -33,7 +32,7 @@ private[spark] class ApplicationInfo(
val id: String,
val desc: ApplicationDescription,
val submitDate: Date,
- val driver: ActorRef,
+ val driver: RpcEndpointRef,
defaultCores: Int)
extends Serializable {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index fccceb3ea528b..48070768f6edb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -21,20 +21,18 @@ import java.io.FileNotFoundException
import java.net.URLEncoder
import java.text.SimpleDateFormat
import java.util.Date
+import java.util.concurrent.{ScheduledFuture, TimeUnit}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.concurrent.Await
-import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random
-import akka.actor._
-import akka.pattern.ask
-import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.Serialization
import akka.serialization.SerializationExtension
import org.apache.hadoop.fs.Path
+import org.apache.spark.rpc.akka.AkkaRpcEnv
+import org.apache.spark.rpc._
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
ExecutorState, SparkHadoopUtil}
@@ -47,23 +45,27 @@ import org.apache.spark.deploy.rest.StandaloneRestServer
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils}
+import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
private[master] class Master(
- host: String,
- port: Int,
+ override val rpcEnv: RpcEnv,
+ address: RpcAddress,
webUiPort: Int,
val securityMgr: SecurityManager,
val conf: SparkConf)
- extends Actor with ActorLogReceive with Logging with LeaderElectable {
+ extends ThreadSafeRpcEndpoint with Logging with LeaderElectable {
- import context.dispatcher // to use Akka's scheduler.schedule()
+ private val forwardMessageThread =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")
+
+ // TODO Remove it once we don't use akka.serialization.Serialization
+ private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
- private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
+ private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
- private val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000
+ private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000
private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200)
private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
@@ -75,10 +77,10 @@ private[master] class Master(
val apps = new HashSet[ApplicationInfo]
private val idToWorker = new HashMap[String, WorkerInfo]
- private val addressToWorker = new HashMap[Address, WorkerInfo]
+ private val addressToWorker = new HashMap[RpcAddress, WorkerInfo]
- private val actorToApp = new HashMap[ActorRef, ApplicationInfo]
- private val addressToApp = new HashMap[Address, ApplicationInfo]
+ private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo]
+ private val addressToApp = new HashMap[RpcAddress, ApplicationInfo]
private val completedApps = new ArrayBuffer[ApplicationInfo]
private var nextAppNumber = 0
private val appIdToUI = new HashMap[String, SparkUI]
@@ -89,21 +91,22 @@ private[master] class Master(
private val waitingDrivers = new ArrayBuffer[DriverInfo]
private var nextDriverNumber = 0
- Utils.checkHost(host, "Expected hostname")
+ Utils.checkHost(address.host, "Expected hostname")
private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr)
private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf,
securityMgr)
private val masterSource = new MasterSource(this)
- private val webUi = new MasterWebUI(this, webUiPort)
+ // After onStart, webUi will be set
+ private var webUi: MasterWebUI = null
private val masterPublicAddress = {
val envVar = conf.getenv("SPARK_PUBLIC_DNS")
- if (envVar != null) envVar else host
+ if (envVar != null) envVar else address.host
}
- private val masterUrl = "spark://" + host + ":" + port
+ private val masterUrl = address.toSparkURL
private var masterWebUiUrl: String = _
private var state = RecoveryState.STANDBY
@@ -112,7 +115,9 @@ private[master] class Master(
private var leaderElectionAgent: LeaderElectionAgent = _
- private var recoveryCompletionTask: Cancellable = _
+ private var recoveryCompletionTask: ScheduledFuture[_] = _
+
+ private var checkForWorkerTimeOutTask: ScheduledFuture[_] = _
// As a temporary workaround before better ways of configuring memory, we allow users to set
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
@@ -130,20 +135,23 @@ private[master] class Master(
private val restServer =
if (restServerEnabled) {
val port = conf.getInt("spark.master.rest.port", 6066)
- Some(new StandaloneRestServer(host, port, conf, self, masterUrl))
+ Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl))
} else {
None
}
private val restServerBoundPort = restServer.map(_.start())
- override def preStart() {
+ override def onStart(): Unit = {
logInfo("Starting Spark master at " + masterUrl)
logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
- // Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ webUi = new MasterWebUI(this, webUiPort)
webUi.bind()
masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort
- context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
+ checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(CheckForWorkerTimeOut)
+ }
+ }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS)
masterMetricsSystem.registerSource(masterSource)
masterMetricsSystem.start()
@@ -157,16 +165,16 @@ private[master] class Master(
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
val zkFactory =
- new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system))
+ new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem))
(zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
val fsFactory =
- new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system))
+ new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem))
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
case "CUSTOM" =>
val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
- .newInstance(conf, SerializationExtension(context.system))
+ .newInstance(conf, SerializationExtension(actorSystem))
.asInstanceOf[StandaloneRecoveryModeFactory]
(factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
@@ -176,18 +184,17 @@ private[master] class Master(
leaderElectionAgent = leaderElectionAgent_
}
- override def preRestart(reason: Throwable, message: Option[Any]) {
- super.preRestart(reason, message) // calls postStop()!
- logError("Master actor restarted due to exception", reason)
- }
-
- override def postStop() {
+ override def onStop() {
masterMetricsSystem.report()
applicationMetricsSystem.report()
// prevent the CompleteRecovery message from being sent to the restarted master
if (recoveryCompletionTask != null) {
- recoveryCompletionTask.cancel()
+ recoveryCompletionTask.cancel(true)
}
+ if (checkForWorkerTimeOutTask != null) {
+ checkForWorkerTimeOutTask.cancel(true)
+ }
+ forwardMessageThread.shutdownNow()
webUi.stop()
restServer.foreach(_.stop())
masterMetricsSystem.stop()
@@ -197,14 +204,14 @@ private[master] class Master(
}
override def electedLeader() {
- self ! ElectedLeader
+ self.send(ElectedLeader)
}
override def revokedLeadership() {
- self ! RevokedLeadership
+ self.send(RevokedLeadership)
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receive: PartialFunction[Any, Unit] = {
case ElectedLeader => {
val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData()
state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
@@ -215,8 +222,11 @@ private[master] class Master(
logInfo("I have been elected leader! New state: " + state)
if (state == RecoveryState.RECOVERING) {
beginRecovery(storedApps, storedDrivers, storedWorkers)
- recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self,
- CompleteRecovery)
+ recoveryCompletionTask = forwardMessageThread.schedule(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(CompleteRecovery)
+ }
+ }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
}
@@ -227,111 +237,42 @@ private[master] class Master(
System.exit(0)
}
- case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) =>
- {
+ case RegisterWorker(
+ id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => {
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
workerHost, workerPort, cores, Utils.megabytesToString(memory)))
if (state == RecoveryState.STANDBY) {
// ignore, don't send response
} else if (idToWorker.contains(id)) {
- sender ! RegisterWorkerFailed("Duplicate worker ID")
+ workerRef.send(RegisterWorkerFailed("Duplicate worker ID"))
} else {
val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
- sender, workerUiPort, publicAddress)
+ workerRef, workerUiPort, publicAddress)
if (registerWorker(worker)) {
persistenceEngine.addWorker(worker)
- sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
+ workerRef.send(RegisteredWorker(self, masterWebUiUrl))
schedule()
} else {
- val workerAddress = worker.actor.path.address
+ val workerAddress = worker.endpoint.address
logWarning("Worker registration failed. Attempted to re-register worker at same " +
"address: " + workerAddress)
- sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: "
- + workerAddress)
- }
- }
- }
-
- case RequestSubmitDriver(description) => {
- if (state != RecoveryState.ALIVE) {
- val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
- "Can only accept driver submissions in ALIVE state."
- sender ! SubmitDriverResponse(false, None, msg)
- } else {
- logInfo("Driver submitted " + description.command.mainClass)
- val driver = createDriver(description)
- persistenceEngine.addDriver(driver)
- waitingDrivers += driver
- drivers.add(driver)
- schedule()
-
- // TODO: It might be good to instead have the submission client poll the master to determine
- // the current status of the driver. For now it's simply "fire and forget".
-
- sender ! SubmitDriverResponse(true, Some(driver.id),
- s"Driver successfully submitted as ${driver.id}")
- }
- }
-
- case RequestKillDriver(driverId) => {
- if (state != RecoveryState.ALIVE) {
- val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
- s"Can only kill drivers in ALIVE state."
- sender ! KillDriverResponse(driverId, success = false, msg)
- } else {
- logInfo("Asked to kill driver " + driverId)
- val driver = drivers.find(_.id == driverId)
- driver match {
- case Some(d) =>
- if (waitingDrivers.contains(d)) {
- waitingDrivers -= d
- self ! DriverStateChanged(driverId, DriverState.KILLED, None)
- } else {
- // We just notify the worker to kill the driver here. The final bookkeeping occurs
- // on the return path when the worker submits a state change back to the master
- // to notify it that the driver was successfully killed.
- d.worker.foreach { w =>
- w.actor ! KillDriver(driverId)
- }
- }
- // TODO: It would be nice for this to be a synchronous response
- val msg = s"Kill request for $driverId submitted"
- logInfo(msg)
- sender ! KillDriverResponse(driverId, success = true, msg)
- case None =>
- val msg = s"Driver $driverId has already finished or does not exist"
- logWarning(msg)
- sender ! KillDriverResponse(driverId, success = false, msg)
- }
- }
- }
-
- case RequestDriverStatus(driverId) => {
- if (state != RecoveryState.ALIVE) {
- val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
- "Can only request driver status in ALIVE state."
- sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg)))
- } else {
- (drivers ++ completedDrivers).find(_.id == driverId) match {
- case Some(driver) =>
- sender ! DriverStatusResponse(found = true, Some(driver.state),
- driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)
- case None =>
- sender ! DriverStatusResponse(found = false, None, None, None, None)
+ workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: "
+ + workerAddress))
}
}
}
- case RegisterApplication(description) => {
+ case RegisterApplication(description, driver) => {
+ // TODO Prevent repeated registrations from the same driver
if (state == RecoveryState.STANDBY) {
// ignore, don't send response
} else {
logInfo("Registering app " + description.name)
- val app = createApplication(description, sender)
+ val app = createApplication(description, driver)
registerApplication(app)
logInfo("Registered app " + description.name + " with ID " + app.id)
persistenceEngine.addApplication(app)
- sender ! RegisteredApplication(app.id, masterUrl)
+ driver.send(RegisteredApplication(app.id, self))
schedule()
}
}
@@ -343,7 +284,7 @@ private[master] class Master(
val appInfo = idToApp(appId)
exec.state = state
if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() }
- exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus)
+ exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus))
if (ExecutorState.isFinished(state)) {
// Remove this executor from the worker and app
logInfo(s"Removing executor ${exec.fullId} because it is $state")
@@ -384,7 +325,7 @@ private[master] class Master(
}
}
- case Heartbeat(workerId) => {
+ case Heartbeat(workerId, worker) => {
idToWorker.get(workerId) match {
case Some(workerInfo) =>
workerInfo.lastHeartbeat = System.currentTimeMillis()
@@ -392,7 +333,7 @@ private[master] class Master(
if (workers.map(_.id).contains(workerId)) {
logWarning(s"Got heartbeat from unregistered worker $workerId." +
" Asking it to re-register.")
- sender ! ReconnectWorker(masterUrl)
+ worker.send(ReconnectWorker(masterUrl))
} else {
logWarning(s"Got heartbeat from unregistered worker $workerId." +
" This worker was never registered, so ignoring the heartbeat.")
@@ -444,30 +385,103 @@ private[master] class Master(
logInfo(s"Received unregister request from application $applicationId")
idToApp.get(applicationId).foreach(finishApplication)
- case DisassociatedEvent(_, address, _) => {
- // The disconnected client could've been either a worker or an app; remove whichever it was
- logInfo(s"$address got disassociated, removing it.")
- addressToWorker.get(address).foreach(removeWorker)
- addressToApp.get(address).foreach(finishApplication)
- if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
+ case CheckForWorkerTimeOut => {
+ timeOutDeadWorkers()
}
+ }
- case RequestMasterState => {
- sender ! MasterStateResponse(
- host, port, restServerBoundPort,
- workers.toArray, apps.toArray, completedApps.toArray,
- drivers.toArray, completedDrivers.toArray, state)
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RequestSubmitDriver(description) => {
+ if (state != RecoveryState.ALIVE) {
+ val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
+ "Can only accept driver submissions in ALIVE state."
+ context.reply(SubmitDriverResponse(self, false, None, msg))
+ } else {
+ logInfo("Driver submitted " + description.command.mainClass)
+ val driver = createDriver(description)
+ persistenceEngine.addDriver(driver)
+ waitingDrivers += driver
+ drivers.add(driver)
+ schedule()
+
+ // TODO: It might be good to instead have the submission client poll the master to determine
+ // the current status of the driver. For now it's simply "fire and forget".
+
+ context.reply(SubmitDriverResponse(self, true, Some(driver.id),
+ s"Driver successfully submitted as ${driver.id}"))
+ }
}
- case CheckForWorkerTimeOut => {
- timeOutDeadWorkers()
+ case RequestKillDriver(driverId) => {
+ if (state != RecoveryState.ALIVE) {
+ val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
+ s"Can only kill drivers in ALIVE state."
+ context.reply(KillDriverResponse(self, driverId, success = false, msg))
+ } else {
+ logInfo("Asked to kill driver " + driverId)
+ val driver = drivers.find(_.id == driverId)
+ driver match {
+ case Some(d) =>
+ if (waitingDrivers.contains(d)) {
+ waitingDrivers -= d
+ self.send(DriverStateChanged(driverId, DriverState.KILLED, None))
+ } else {
+ // We just notify the worker to kill the driver here. The final bookkeeping occurs
+ // on the return path when the worker submits a state change back to the master
+ // to notify it that the driver was successfully killed.
+ d.worker.foreach { w =>
+ w.endpoint.send(KillDriver(driverId))
+ }
+ }
+ // TODO: It would be nice for this to be a synchronous response
+ val msg = s"Kill request for $driverId submitted"
+ logInfo(msg)
+ context.reply(KillDriverResponse(self, driverId, success = true, msg))
+ case None =>
+ val msg = s"Driver $driverId has already finished or does not exist"
+ logWarning(msg)
+ context.reply(KillDriverResponse(self, driverId, success = false, msg))
+ }
+ }
+ }
+
+ case RequestDriverStatus(driverId) => {
+ if (state != RecoveryState.ALIVE) {
+ val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
+ "Can only request driver status in ALIVE state."
+ context.reply(
+ DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))))
+ } else {
+ (drivers ++ completedDrivers).find(_.id == driverId) match {
+ case Some(driver) =>
+ context.reply(DriverStatusResponse(found = true, Some(driver.state),
+ driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception))
+ case None =>
+ context.reply(DriverStatusResponse(found = false, None, None, None, None))
+ }
+ }
+ }
+
+ case RequestMasterState => {
+ context.reply(MasterStateResponse(
+ address.host, address.port, restServerBoundPort,
+ workers.toArray, apps.toArray, completedApps.toArray,
+ drivers.toArray, completedDrivers.toArray, state))
}
case BoundPortsRequest => {
- sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort)
+ context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort))
}
}
+ override def onDisconnected(address: RpcAddress): Unit = {
+ // The disconnected client could've been either a worker or an app; remove whichever it was
+ logInfo(s"$address got disassociated, removing it.")
+ addressToWorker.get(address).foreach(removeWorker)
+ addressToApp.get(address).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
+ }
+
private def canCompleteRecovery =
workers.count(_.state == WorkerState.UNKNOWN) == 0 &&
apps.count(_.state == ApplicationState.UNKNOWN) == 0
@@ -479,7 +493,7 @@ private[master] class Master(
try {
registerApplication(app)
app.state = ApplicationState.UNKNOWN
- app.driver ! MasterChanged(masterUrl, masterWebUiUrl)
+ app.driver.send(MasterChanged(self, masterWebUiUrl))
} catch {
case e: Exception => logInfo("App " + app.id + " had exception on reconnect")
}
@@ -496,7 +510,7 @@ private[master] class Master(
try {
registerWorker(worker)
worker.state = WorkerState.UNKNOWN
- worker.actor ! MasterChanged(masterUrl, masterWebUiUrl)
+ worker.endpoint.send(MasterChanged(self, masterWebUiUrl))
} catch {
case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect")
}
@@ -505,10 +519,8 @@ private[master] class Master(
private def completeRecovery() {
// Ensure "only-once" recovery semantics using a short synchronization period.
- synchronized {
- if (state != RecoveryState.RECOVERING) { return }
- state = RecoveryState.COMPLETING_RECOVERY
- }
+ if (state != RecoveryState.RECOVERING) { return }
+ state = RecoveryState.COMPLETING_RECOVERY
// Kill off any workers and apps that didn't respond to us.
workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
@@ -623,10 +635,10 @@ private[master] class Master(
private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(masterUrl,
- exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)
- exec.application.driver ! ExecutorAdded(
- exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
+ worker.endpoint.send(LaunchExecutor(masterUrl,
+ exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory))
+ exec.application.driver.send(ExecutorAdded(
+ exec.id, worker.id, worker.hostPort, exec.cores, exec.memory))
}
private def registerWorker(worker: WorkerInfo): Boolean = {
@@ -638,7 +650,7 @@ private[master] class Master(
workers -= w
}
- val workerAddress = worker.actor.path.address
+ val workerAddress = worker.endpoint.address
if (addressToWorker.contains(workerAddress)) {
val oldWorker = addressToWorker(workerAddress)
if (oldWorker.state == WorkerState.UNKNOWN) {
@@ -661,11 +673,11 @@ private[master] class Master(
logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port)
worker.setState(WorkerState.DEAD)
idToWorker -= worker.id
- addressToWorker -= worker.actor.path.address
+ addressToWorker -= worker.endpoint.address
for (exec <- worker.executors.values) {
logInfo("Telling app of lost executor: " + exec.id)
- exec.application.driver ! ExecutorUpdated(
- exec.id, ExecutorState.LOST, Some("worker lost"), None)
+ exec.application.driver.send(ExecutorUpdated(
+ exec.id, ExecutorState.LOST, Some("worker lost"), None))
exec.application.removeExecutor(exec)
}
for (driver <- worker.drivers.values) {
@@ -687,14 +699,15 @@ private[master] class Master(
schedule()
}
- private def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
+ private def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef):
+ ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores)
}
private def registerApplication(app: ApplicationInfo): Unit = {
- val appAddress = app.driver.path.address
+ val appAddress = app.driver.address
if (addressToApp.contains(appAddress)) {
logInfo("Attempted to re-register application at same address: " + appAddress)
return
@@ -703,7 +716,7 @@ private[master] class Master(
applicationMetricsSystem.registerSource(app.appSource)
apps += app
idToApp(app.id) = app
- actorToApp(app.driver) = app
+ endpointToApp(app.driver) = app
addressToApp(appAddress) = app
waitingApps += app
}
@@ -717,8 +730,8 @@ private[master] class Master(
logInfo("Removing app " + app.id)
apps -= app
idToApp -= app.id
- actorToApp -= app.driver
- addressToApp -= app.driver.path.address
+ endpointToApp -= app.driver
+ addressToApp -= app.driver.address
if (completedApps.size >= RETAINED_APPLICATIONS) {
val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1)
completedApps.take(toRemove).foreach( a => {
@@ -735,19 +748,19 @@ private[master] class Master(
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
- exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id)
+ exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id))
exec.state = ExecutorState.KILLED
}
app.markFinished(state)
if (state != ApplicationState.FINISHED) {
- app.driver ! ApplicationRemoved(state.toString)
+ app.driver.send(ApplicationRemoved(state.toString))
}
persistenceEngine.removeApplication(app)
schedule()
// Tell all workers that the application has finished, so they can clean up any app state.
workers.foreach { w =>
- w.actor ! ApplicationFinished(app.id)
+ w.endpoint.send(ApplicationFinished(app.id))
}
}
}
@@ -768,7 +781,7 @@ private[master] class Master(
}
val eventLogFilePrefix = EventLoggingListener.getLogPath(
- eventLogDir, app.id, None, app.desc.eventLogCodec)
+ eventLogDir, app.id, app.desc.eventLogCodec)
val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf)
val inProgressExists = fs.exists(new Path(eventLogFilePrefix +
EventLoggingListener.IN_PROGRESS))
@@ -832,14 +845,14 @@ private[master] class Master(
private def timeOutDeadWorkers() {
// Copy the workers into an array so we don't modify the hashset while iterating through it
val currentTime = System.currentTimeMillis()
- val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray
+ val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray
for (worker <- toRemove) {
if (worker.state != WorkerState.DEAD) {
logWarning("Removing %s because we got no heartbeat in %d seconds".format(
- worker.id, WORKER_TIMEOUT/1000))
+ worker.id, WORKER_TIMEOUT_MS / 1000))
removeWorker(worker)
} else {
- if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) {
+ if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) {
workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it
}
}
@@ -862,7 +875,7 @@ private[master] class Master(
logInfo("Launching driver " + driver.id + " on worker " + worker.id)
worker.addDriver(driver)
driver.worker = Some(worker)
- worker.actor ! LaunchDriver(driver.id, driver.desc)
+ worker.endpoint.send(LaunchDriver(driver.id, driver.desc))
driver.state = DriverState.RUNNING
}
@@ -891,57 +904,33 @@ private[master] class Master(
}
private[deploy] object Master extends Logging {
- val systemName = "sparkMaster"
- private val actorName = "Master"
+ val SYSTEM_NAME = "sparkMaster"
+ val ENDPOINT_NAME = "Master"
def main(argStrings: Array[String]) {
SignalLogger.register(log)
val conf = new SparkConf
val args = new MasterArguments(argStrings, conf)
- val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf)
- actorSystem.awaitTermination()
- }
-
- /**
- * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`.
- *
- * @throws SparkException if the url is invalid
- */
- def toAkkaUrl(sparkUrl: String, protocol: String): String = {
- val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
- AkkaUtils.address(protocol, systemName, host, port, actorName)
- }
-
- /**
- * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`.
- *
- * @throws SparkException if the url is invalid
- */
- def toAkkaAddress(sparkUrl: String, protocol: String): Address = {
- val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
- Address(protocol, systemName, host, port)
+ val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf)
+ rpcEnv.awaitTermination()
}
/**
- * Start the Master and return a four tuple of:
- * (1) The Master actor system
- * (2) The bound port
- * (3) The web UI bound port
- * (4) The REST server bound port, if any
+ * Start the Master and return a three tuple of:
+ * (1) The Master RpcEnv
+ * (2) The web UI bound port
+ * (3) The REST server bound port, if any
*/
- def startSystemAndActor(
+ def startRpcEnvAndEndpoint(
host: String,
port: Int,
webUiPort: Int,
- conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = {
+ conf: SparkConf): (RpcEnv, Int, Option[Int]) = {
val securityMgr = new SecurityManager(conf)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
- securityManager = securityMgr)
- val actor = actorSystem.actorOf(
- Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName)
- val timeout = RpcUtils.askTimeout(conf)
- val portsRequest = actor.ask(BoundPortsRequest)(timeout)
- val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse]
- (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort)
+ val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
+ val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME,
+ new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
+ val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest)
+ (rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
}
}
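
With the Akka startup helpers removed, callers consume the three-element tuple returned by `startRpcEnvAndEndpoint` as `main` does above. A minimal sketch, assuming the private[deploy] visibility is satisfied (i.e. the code lives inside Spark itself) and using illustrative host and port values:

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.master.Master

    object MasterStartSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        // Returns the Master RpcEnv, the bound web UI port, and the REST server port
        // (present only when the standalone REST server is enabled), per the scaladoc above.
        val (rpcEnv, webUiPort, restPort) =
          Master.startRpcEnvAndEndpoint("localhost", 7077, 8080, conf)
        println(s"web UI on $webUiPort, REST server on ${restPort.getOrElse("disabled")}")
        rpcEnv.awaitTermination()  // replaces actorSystem.awaitTermination()
      }
    }
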
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
index 435b9b12f83b8..44cefbc77f08e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -85,6 +85,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) {
* Print usage and exit JVM with the given exit code.
*/
private def printUsageAndExit(exitCode: Int) {
+ // scalastyle:off println
System.err.println(
"Usage: Master [options]\n" +
"\n" +
@@ -95,6 +96,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) {
" --webui-port PORT Port for web UI (default: 8080)\n" +
" --properties-file FILE Path to a custom Spark properties file.\n" +
" Default is conf/spark-defaults.conf.")
+ // scalastyle:on println
System.exit(exitCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
index 15c6296888f70..68c937188b333 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -28,7 +28,7 @@ private[master] object MasterMessages {
case object RevokedLeadership
- // Actor System to Master
+ // Master to itself
case object CheckForWorkerTimeOut
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index 9b3d48c6edc84..f751966605206 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -19,9 +19,7 @@ package org.apache.spark.deploy.master
import scala.collection.mutable
-import akka.actor.ActorRef
-
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
@@ -30,7 +28,7 @@ private[spark] class WorkerInfo(
val port: Int,
val cores: Int,
val memory: Int,
- val actor: ActorRef,
+ val endpoint: RpcEndpointRef,
val webUiPort: Int,
val publicAddress: String)
extends Serializable {
@@ -107,4 +105,6 @@ private[spark] class WorkerInfo(
def setState(state: WorkerState.Value): Unit = {
this.state = state
}
+
+ def isAlive(): Boolean = this.state == WorkerState.ALIVE
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 52758d6a7c4be..6fdff86f66e01 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -17,10 +17,7 @@
package org.apache.spark.deploy.master
-import akka.actor.ActorRef
-
import org.apache.spark.{Logging, SparkConf}
-import org.apache.spark.deploy.master.MasterMessages._
import org.apache.curator.framework.CuratorFramework
import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
import org.apache.spark.deploy.SparkCuratorUtil
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 06e265f99e231..e28e7e379ac91 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -19,11 +19,8 @@ package org.apache.spark.deploy.master.ui
import javax.servlet.http.HttpServletRequest
-import scala.concurrent.Await
import scala.xml.Node
-import akka.pattern.ask
-
import org.apache.spark.deploy.ExecutorState
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master.ExecutorDesc
@@ -32,14 +29,12 @@ import org.apache.spark.util.Utils
private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") {
- private val master = parent.masterActorRef
- private val timeout = parent.timeout
+ private val master = parent.masterEndpointRef
/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
- val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, timeout)
+ val state = master.askWithRetry[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 6a7c74020bace..c3e20ebf8d6eb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -19,25 +19,21 @@ package org.apache.spark.deploy.master.ui
import javax.servlet.http.HttpServletRequest
-import scala.concurrent.Await
import scala.xml.Node
-import akka.pattern.ask
import org.json4s.JValue
import org.apache.spark.deploy.JsonProtocol
-import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState}
+import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master._
import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.Utils
private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
- private val master = parent.masterActorRef
- private val timeout = parent.timeout
+ private val master = parent.masterEndpointRef
def getMasterState: MasterStateResponse = {
- val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- Await.result(stateFuture, timeout)
+ master.askWithRetry[MasterStateResponse](RequestMasterState)
}
override def renderJson(request: HttpServletRequest): JValue = {
@@ -53,7 +49,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
def handleDriverKillRequest(request: HttpServletRequest): Unit = {
- handleKillRequest(request, id => { master ! RequestKillDriver(id) })
+ handleKillRequest(request, id => {
+ master.ask[KillDriverResponse](RequestKillDriver(id))
+ })
}
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 2111a8581f2e4..6174fc11f83d8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -23,7 +23,6 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource
UIRoot}
import org.apache.spark.ui.{SparkUI, WebUI}
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.util.RpcUtils
/**
* Web UI server for the standalone master.
@@ -33,8 +32,7 @@ class MasterWebUI(val master: Master, requestedPort: Int)
extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging
with UIRoot {
- val masterActorRef = master.self
- val timeout = RpcUtils.askTimeout(master.conf)
+ val masterEndpointRef = master.self
val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true)
val masterPage = new MasterPage(this)
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
index 894cb78d8591a..5accaf78d0a51 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
@@ -54,7 +54,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
case ("--master" | "-m") :: value :: tail =>
if (!value.startsWith("mesos://")) {
+ // scalastyle:off println
System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)")
+ // scalastyle:on println
System.exit(1)
}
masterUrl = value.stripPrefix("mesos://")
@@ -73,7 +75,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
case Nil => {
if (masterUrl == null) {
+ // scalastyle:off println
System.err.println("--master is required")
+ // scalastyle:on println
printUsageAndExit(1)
}
}
@@ -83,6 +87,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
}
private def printUsageAndExit(exitCode: Int): Unit = {
+ // scalastyle:off println
System.err.println(
"Usage: MesosClusterDispatcher [options]\n" +
"\n" +
@@ -96,6 +101,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
" Zookeeper for persistence\n" +
" --properties-file FILE Path to a custom Spark properties file.\n" +
" Default is conf/spark-defaults.conf.")
+ // scalastyle:on println
System.exit(exitCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
index 502b9bb701ccf..d5b9bcab1423f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -20,10 +20,10 @@ package org.apache.spark.deploy.rest
import java.io.File
import javax.servlet.http.HttpServletResponse
-import akka.actor.ActorRef
import org.apache.spark.deploy.ClientArguments._
import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription}
-import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.util.Utils
import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
/**
@@ -45,35 +45,34 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
* @param host the address this server should bind to
* @param requestedPort the port this server will attempt to bind to
* @param masterConf the conf used by the Master
- * @param masterActor reference to the Master actor to which requests can be sent
+ * @param masterEndpoint reference to the Master endpoint to which requests can be sent
* @param masterUrl the URL of the Master new drivers will attempt to connect to
*/
private[deploy] class StandaloneRestServer(
host: String,
requestedPort: Int,
masterConf: SparkConf,
- masterActor: ActorRef,
+ masterEndpoint: RpcEndpointRef,
masterUrl: String)
extends RestSubmissionServer(host, requestedPort, masterConf) {
protected override val submitRequestServlet =
- new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf)
+ new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf)
protected override val killRequestServlet =
- new StandaloneKillRequestServlet(masterActor, masterConf)
+ new StandaloneKillRequestServlet(masterEndpoint, masterConf)
protected override val statusRequestServlet =
- new StandaloneStatusRequestServlet(masterActor, masterConf)
+ new StandaloneStatusRequestServlet(masterEndpoint, masterConf)
}
/**
* A servlet for handling kill requests passed to the [[StandaloneRestServer]].
*/
-private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf)
+private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf)
extends KillRequestServlet {
protected def handleKill(submissionId: String): KillSubmissionResponse = {
- val askTimeout = RpcUtils.askTimeout(conf)
- val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
- DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
+ val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse](
+ DeployMessages.RequestKillDriver(submissionId))
val k = new KillSubmissionResponse
k.serverSparkVersion = sparkVersion
k.message = response.message
@@ -86,13 +85,12 @@ private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: Sp
/**
* A servlet for handling status requests passed to the [[StandaloneRestServer]].
*/
-private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
+private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf)
extends StatusRequestServlet {
protected def handleStatus(submissionId: String): SubmissionStatusResponse = {
- val askTimeout = RpcUtils.askTimeout(conf)
- val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
- DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout)
+ val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse](
+ DeployMessages.RequestDriverStatus(submissionId))
val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) }
val d = new SubmissionStatusResponse
d.serverSparkVersion = sparkVersion
@@ -110,7 +108,7 @@ private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf:
* A servlet for handling submit requests passed to the [[StandaloneRestServer]].
*/
private[rest] class StandaloneSubmitRequestServlet(
- masterActor: ActorRef,
+ masterEndpoint: RpcEndpointRef,
masterUrl: String,
conf: SparkConf)
extends SubmitRequestServlet {
@@ -175,10 +173,9 @@ private[rest] class StandaloneSubmitRequestServlet(
responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
requestMessage match {
case submitRequest: CreateSubmissionRequest =>
- val askTimeout = RpcUtils.askTimeout(conf)
val driverDescription = buildDriverDescription(submitRequest)
- val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse](
- DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout)
+ val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse](
+ DeployMessages.RequestSubmitDriver(driverDescription))
val submitResponse = new CreateSubmissionResponse
submitResponse.serverSparkVersion = sparkVersion
submitResponse.message = response.message
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala
index 8198296eeb341..868cc35d06ef3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala
@@ -59,7 +59,7 @@ private[mesos] class MesosSubmitRequestServlet(
extends SubmitRequestServlet {
private val DEFAULT_SUPERVISE = false
- private val DEFAULT_MEMORY = 512 // mb
+ private val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // mb
private val DEFAULT_CORES = 1.0
private val nextDriverNumber = new AtomicLong(0)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 1386055eb8c48..ec51c3d935d8e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -21,7 +21,6 @@ import java.io._
import scala.collection.JavaConversions._
-import akka.actor.ActorRef
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.hadoop.fs.Path
@@ -31,6 +30,7 @@ import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil}
import org.apache.spark.deploy.DeployMessages.DriverStateChanged
import org.apache.spark.deploy.master.DriverState
import org.apache.spark.deploy.master.DriverState.DriverState
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.{Utils, Clock, SystemClock}
/**
@@ -43,7 +43,7 @@ private[deploy] class DriverRunner(
val workDir: File,
val sparkHome: File,
val driverDesc: DriverDescription,
- val worker: ActorRef,
+ val worker: RpcEndpointRef,
val workerUrl: String,
val securityManager: SecurityManager)
extends Logging {
@@ -107,7 +107,7 @@ private[deploy] class DriverRunner(
finalState = Some(state)
- worker ! DriverStateChanged(driverId, state, finalException)
+ worker.send(DriverStateChanged(driverId, state, finalException))
}
}.start()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index d1a12b01e78f7..2d6be3042c905 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -60,7 +60,9 @@ object DriverWrapper {
rpcEnv.shutdown()
case _ =>
+ // scalastyle:off println
System.err.println("Usage: DriverWrapper [options]")
+ // scalastyle:on println
System.exit(-1)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index fff17e1095042..29a5042285578 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -21,10 +21,10 @@ import java.io._
import scala.collection.JavaConversions._
-import akka.actor.ActorRef
import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
@@ -41,7 +41,7 @@ private[deploy] class ExecutorRunner(
val appDesc: ApplicationDescription,
val cores: Int,
val memory: Int,
- val worker: ActorRef,
+ val worker: RpcEndpointRef,
val workerId: String,
val host: String,
val webUiPort: Int,
@@ -91,7 +91,7 @@ private[deploy] class ExecutorRunner(
process.destroy()
exitCode = Some(process.waitFor())
}
- worker ! ExecutorStateChanged(appId, execId, state, message, exitCode)
+ worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
}
/** Stop this executor runner, including killing the process it launched */
@@ -159,7 +159,7 @@ private[deploy] class ExecutorRunner(
val exitCode = process.waitFor()
state = ExecutorState.EXITED
val message = "Command exited with code " + exitCode
- worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))
+ worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)))
} catch {
case interrupted: InterruptedException => {
logInfo("Runner thread for executor " + fullId + " interrupted")
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index ebc6cd76c6afd..82e9578bbcba5 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -21,15 +21,14 @@ import java.io.File
import java.io.IOException
import java.text.SimpleDateFormat
import java.util.{UUID, Date}
+import java.util.concurrent._
+import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, HashSet}
-import scala.concurrent.duration._
-import scala.language.postfixOps
+import scala.concurrent.ExecutionContext
import scala.util.Random
-
-import akka.actor._
-import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+import scala.util.control.NonFatal
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState}
@@ -38,32 +37,39 @@ import org.apache.spark.deploy.ExternalShuffleService
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.rpc._
+import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
-/**
- * @param masterAkkaUrls Each url should be a valid akka url.
- */
private[worker] class Worker(
- host: String,
- port: Int,
+ override val rpcEnv: RpcEnv,
webUiPort: Int,
cores: Int,
memory: Int,
- masterAkkaUrls: Array[String],
- actorSystemName: String,
- actorName: String,
+ masterRpcAddresses: Array[RpcAddress],
+ systemName: String,
+ endpointName: String,
workDirPath: String = null,
val conf: SparkConf,
val securityMgr: SecurityManager)
- extends Actor with ActorLogReceive with Logging {
- import context.dispatcher
+ extends ThreadSafeRpcEndpoint with Logging {
+
+ private val host = rpcEnv.address.host
+ private val port = rpcEnv.address.port
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
+ // A scheduled executor used to send messages at the specified time.
+ private val forwordMessageScheduler =
+ ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler")
+
+ // A separate thread to clean up the workDir. Used to provide the implicit parameter of `Future`
+ // methods.
+ private val cleanupThreadExecutor = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread"))
+
// For worker and executor IDs
private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
-
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4
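
The worker replaces Akka's scheduler with a daemon single-thread scheduled executor that forwards messages to itself. A sketch of that pattern under the assumption that a plain `java.util.concurrent` executor is acceptable here (the patch uses Spark's `ThreadUtils` helpers instead):

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import scala.util.control.NonFatal

object HeartbeatSchedulerSketch {
  // Single-threaded scheduler standing in for the worker's forward-message scheduler.
  private val scheduler: ScheduledExecutorService =
    Executors.newSingleThreadScheduledExecutor()

  // Periodically run `sendHeartbeat` (standing in for self.send(SendHeartbeat)),
  // swallowing non-fatal errors so one failed run does not kill the schedule.
  def start(sendHeartbeat: () => Unit, intervalMs: Long): Unit = {
    scheduler.scheduleAtFixedRate(new Runnable {
      override def run(): Unit =
        try sendHeartbeat()
        catch { case NonFatal(e) => e.printStackTrace() }
    }, 0, intervalMs, TimeUnit.MILLISECONDS)
  }
}
```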
@@ -79,32 +85,26 @@ private[worker] class Worker(
val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits)
randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND
}
- private val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 *
- REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
- private val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60
- * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
+ private val INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(10 *
+ REGISTRATION_RETRY_FUZZ_MULTIPLIER))
+ private val PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(60
+ * REGISTRATION_RETRY_FUZZ_MULTIPLIER))
private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false)
// How often worker will clean up old app folders
private val CLEANUP_INTERVAL_MILLIS =
conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000
// TTL for app folders/data; after TTL expires it will be cleaned up
- private val APP_DATA_RETENTION_SECS =
+ private val APP_DATA_RETENTION_SECONDS =
conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600)
private val testing: Boolean = sys.props.contains("spark.testing")
- private var master: ActorSelection = null
- private var masterAddress: Address = null
+ private var master: Option[RpcEndpointRef] = None
private var activeMasterUrl: String = ""
private[worker] var activeMasterWebUiUrl : String = ""
- private val akkaUrl = AkkaUtils.address(
- AkkaUtils.protocol(context.system),
- actorSystemName,
- host,
- port,
- actorName)
- @volatile private var registered = false
- @volatile private var connected = false
+ private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName)
+ private var registered = false
+ private var connected = false
private val workerId = generateWorkerId()
private val sparkHome =
if (testing) {
@@ -136,7 +136,18 @@ private[worker] class Worker(
private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
private val workerSource = new WorkerSource(this)
- private var registrationRetryTimer: Option[Cancellable] = None
+ private var registerMasterFutures: Array[JFuture[_]] = null
+ private var registrationRetryTimer: Option[JScheduledFuture[_]] = None
+
+ // A thread pool for registering with masters. Because registering with a master is a blocking
+ // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same
+ // time so that we can register with all masters.
+ private val registerMasterThreadPool = new ThreadPoolExecutor(
+ 0,
+ masterRpcAddresses.size, // Make sure we can register with all masters at the same time
+ 60L, TimeUnit.SECONDS,
+ new SynchronousQueue[Runnable](),
+ ThreadUtils.namedThreadFactory("worker-register-master-threadpool"))
var coresUsed = 0
var memoryUsed = 0
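
The register-master pool above is deliberately bounded by the number of masters. A minimal sketch of the same executor configuration, with `numMasters` standing in for `masterRpcAddresses.size`:

```scala
import java.util.concurrent.{SynchronousQueue, ThreadPoolExecutor, TimeUnit}

object RegisterPoolSketch {
  // Core size 0, max size = number of masters, SynchronousQueue: each submitted
  // registration gets its own (short-lived) thread, so all masters can be contacted
  // concurrently even though each registration call blocks.
  def newRegisterPool(numMasters: Int): ThreadPoolExecutor = {
    require(numMasters > 0, "need at least one master address")
    new ThreadPoolExecutor(
      0,
      numMasters,
      60L, TimeUnit.SECONDS,  // idle threads are reclaimed after 60 seconds
      new SynchronousQueue[Runnable]())
  }
}
```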
@@ -162,14 +173,13 @@ private[worker] class Worker(
}
}
- override def preStart() {
+ override def onStart() {
assert(!registered)
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
host, port, cores, Utils.megabytesToString(memory)))
logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
logInfo("Spark home: " + sparkHome)
createWorkDir()
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
@@ -181,24 +191,32 @@ private[worker] class Worker(
metricsSystem.getServletHandlers.foreach(webUi.attachHandler)
}
- private def changeMaster(url: String, uiUrl: String) {
+ private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) {
// activeMasterUrl is a valid Spark url since we receive it from the master.
- activeMasterUrl = url
+ activeMasterUrl = masterRef.address.toSparkURL
activeMasterWebUiUrl = uiUrl
- master = context.actorSelection(
- Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system)))
- masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system))
+ master = Some(masterRef)
connected = true
// Cancel any outstanding re-registration attempts because we found a new master
- registrationRetryTimer.foreach(_.cancel())
- registrationRetryTimer = None
+ cancelLastRegistrationRetry()
}
- private def tryRegisterAllMasters() {
- for (masterAkkaUrl <- masterAkkaUrls) {
- logInfo("Connecting to master " + masterAkkaUrl + "...")
- val actor = context.actorSelection(masterAkkaUrl)
- actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
+ private def tryRegisterAllMasters(): Array[JFuture[_]] = {
+ masterRpcAddresses.map { masterAddress =>
+ registerMasterThreadPool.submit(new Runnable {
+ override def run(): Unit = {
+ try {
+ logInfo("Connecting to master " + masterAddress + "...")
+ val masterEndpoint =
+ rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
+ masterEndpoint.send(RegisterWorker(
+ workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress))
+ } catch {
+ case ie: InterruptedException => // Cancelled
+ case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
+ }
+ }
+ })
}
}
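
`tryRegisterAllMasters` resolves each master's endpoint ref and sends the registration message off the endpoint's own message loop, because the resolution blocks. A hedged sketch of one such connect-and-send step; `Hello` is a hypothetical message and "sparkMaster"/"Master" are placeholder system and endpoint names:

```scala
import scala.util.control.NonFatal
import org.apache.spark.rpc.{RpcAddress, RpcEnv}

// Hypothetical registration message; the real patch sends RegisterWorker.
case class Hello(workerId: String)

object ConnectSketch {
  // setupEndpointRef blocks while resolving the remote endpoint, which is why the patch
  // runs this work on the register-master thread pool.
  def register(rpcEnv: RpcEnv, masterAddress: RpcAddress, workerId: String): Unit = {
    try {
      val masterRef = rpcEnv.setupEndpointRef("sparkMaster", masterAddress, "Master")
      masterRef.send(Hello(workerId))
    } catch {
      case _: InterruptedException => // cancelled while connecting
      case NonFatal(_) => // the real code logs a warning and moves on
    }
  }
}
```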
@@ -211,8 +229,7 @@ private[worker] class Worker(
Utils.tryOrExit {
connectionAttemptCount += 1
if (registered) {
- registrationRetryTimer.foreach(_.cancel())
- registrationRetryTimer = None
+ cancelLastRegistrationRetry()
} else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
/**
@@ -235,21 +252,48 @@ private[worker] class Worker(
* still not safe if the old master recovers within this interval, but this is a much
* less likely scenario.
*/
- if (master != null) {
- master ! RegisterWorker(
- workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
- } else {
- // We are retrying the initial registration
- tryRegisterAllMasters()
+ master match {
+ case Some(masterRef) =>
+ // registered == false && master != None means we lost the connection to master, so
+ // masterRef cannot be used and we need to recreate it again. Note: we must not set
+ // master to None due to the above comments.
+ if (registerMasterFutures != null) {
+ registerMasterFutures.foreach(_.cancel(true))
+ }
+ val masterAddress = masterRef.address
+ registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable {
+ override def run(): Unit = {
+ try {
+ logInfo("Connecting to master " + masterAddress + "...")
+ val masterEndpoint =
+ rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
+ masterEndpoint.send(RegisterWorker(
+ workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress))
+ } catch {
+ case ie: InterruptedException => // Cancelled
+ case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
+ }
+ }
+ }))
+ case None =>
+ if (registerMasterFutures != null) {
+ registerMasterFutures.foreach(_.cancel(true))
+ }
+ // We are retrying the initial registration
+ registerMasterFutures = tryRegisterAllMasters()
}
// We have exceeded the initial registration retry threshold
// All retries from now on should use a higher interval
if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
- registrationRetryTimer.foreach(_.cancel())
- registrationRetryTimer = Some {
- context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL,
- PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
- }
+ registrationRetryTimer.foreach(_.cancel(true))
+ registrationRetryTimer = Some(
+ forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(ReregisterWithMaster)
+ }
+ }, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
+ PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS,
+ TimeUnit.SECONDS))
}
} else {
logError("All masters are unresponsive! Giving up.")
@@ -258,41 +302,67 @@ private[worker] class Worker(
}
}
+ /**
+ * Cancel the last registration retry, or do nothing if there is no retry in progress
+ */
+ private def cancelLastRegistrationRetry(): Unit = {
+ if (registerMasterFutures != null) {
+ registerMasterFutures.foreach(_.cancel(true))
+ registerMasterFutures = null
+ }
+ registrationRetryTimer.foreach(_.cancel(true))
+ registrationRetryTimer = None
+ }
+
private def registerWithMaster() {
- // DisassociatedEvent may be triggered multiple times, so don't attempt registration
+ // onDisconnected may be triggered multiple times, so don't attempt registration
// if there are outstanding registration attempts scheduled.
registrationRetryTimer match {
case None =>
registered = false
- tryRegisterAllMasters()
+ registerMasterFutures = tryRegisterAllMasters()
connectionAttemptCount = 0
- registrationRetryTimer = Some {
- context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL,
- INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
- }
+ registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate(
+ new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(ReregisterWithMaster)
+ }
+ },
+ INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
+ INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
+ TimeUnit.SECONDS))
case Some(_) =>
logInfo("Not spawning another attempt to register with the master, since there is an" +
" attempt scheduled already.")
}
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
- case RegisteredWorker(masterUrl, masterWebUiUrl) =>
- logInfo("Successfully registered with master " + masterUrl)
+ override def receive: PartialFunction[Any, Unit] = {
+ case RegisteredWorker(masterRef, masterWebUiUrl) =>
+ logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
registered = true
- changeMaster(masterUrl, masterWebUiUrl)
- context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
+ changeMaster(masterRef, masterWebUiUrl)
+ forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(SendHeartbeat)
+ }
+ }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
if (CLEANUP_ENABLED) {
logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir")
- context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis,
- CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup)
+ forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ self.send(WorkDirCleanup)
+ }
+ }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
}
case SendHeartbeat =>
- if (connected) { master ! Heartbeat(workerId) }
+ if (connected) { sendToMaster(Heartbeat(workerId, self)) }
case WorkDirCleanup =>
// Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor
+ // Copy the ids so that they can be used in the cleanup thread.
+ val appIds = executors.values.map(_.appId).toSet
val cleanupFuture = concurrent.future {
val appDirs = workDir.listFiles()
if (appDirs == null) {
@@ -302,27 +372,27 @@ private[worker] class Worker(
// the directory is used by an application - check that the application is not running
// when cleaning up
val appIdFromDir = dir.getName
- val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir)
+ val isAppStillRunning = appIds.contains(appIdFromDir)
dir.isDirectory && !isAppStillRunning &&
- !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS)
+ !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS)
}.foreach { dir =>
logInfo(s"Removing directory: ${dir.getPath}")
Utils.deleteRecursively(dir)
}
- }
+ }(cleanupThreadExecutor)
- cleanupFuture onFailure {
+ cleanupFuture.onFailure {
case e: Throwable =>
logError("App dir cleanup failed: " + e.getMessage, e)
- }
+ }(cleanupThreadExecutor)
- case MasterChanged(masterUrl, masterWebUiUrl) =>
- logInfo("Master has changed, new master is at " + masterUrl)
- changeMaster(masterUrl, masterWebUiUrl)
+ case MasterChanged(masterRef, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL)
+ changeMaster(masterRef, masterWebUiUrl)
val execs = executors.values.
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
- sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)
+ masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq))
case RegisterWorkerFailed(message) =>
if (!registered) {
@@ -369,14 +439,14 @@ private[worker] class Worker(
publicAddress,
sparkHome,
executorDir,
- akkaUrl,
+ workerUri,
conf,
appLocalDirs, ExecutorState.LOADING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
memoryUsed += memory_
- master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
+ sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None))
} catch {
case e: Exception => {
logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e)
@@ -384,14 +454,14 @@ private[worker] class Worker(
executors(appId + "/" + execId).kill()
executors -= appId + "/" + execId
}
- master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
- Some(e.toString), None)
+ sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
+ Some(e.toString), None))
}
}
}
- case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
- master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
+ sendToMaster(executorStateChanged)
val fullId = appId + "/" + execId
if (ExecutorState.isFinished(state)) {
executors.get(fullId) match {
@@ -434,7 +504,7 @@ private[worker] class Worker(
sparkHome,
driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)),
self,
- akkaUrl,
+ workerUri,
securityMgr)
drivers(driverId) = driver
driver.start()
@@ -453,7 +523,7 @@ private[worker] class Worker(
}
}
- case DriverStateChanged(driverId, state, exception) => {
+ case driverStateChanged @ DriverStateChanged(driverId, state, exception) => {
state match {
case DriverState.ERROR =>
logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}")
@@ -466,23 +536,13 @@ private[worker] class Worker(
case _ =>
logDebug(s"Driver $driverId changed state to $state")
}
- master ! DriverStateChanged(driverId, state, exception)
+ sendToMaster(driverStateChanged)
val driver = drivers.remove(driverId).get
finishedDrivers(driverId) = driver
memoryUsed -= driver.driverDesc.mem
coresUsed -= driver.driverDesc.cores
}
- case x: DisassociatedEvent if x.remoteAddress == masterAddress =>
- logInfo(s"$x Disassociated !")
- masterDisconnected()
-
- case RequestWorkerState =>
- sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
- finishedExecutors.values.toList, drivers.values.toList,
- finishedDrivers.values.toList, activeMasterUrl, cores, memory,
- coresUsed, memoryUsed, activeMasterWebUiUrl)
-
case ReregisterWithMaster =>
reregisterWithMaster()
@@ -491,6 +551,21 @@ private[worker] class Worker(
maybeCleanupApplication(id)
}
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RequestWorkerState =>
+ context.reply(WorkerStateResponse(host, port, workerId, executors.values.toList,
+ finishedExecutors.values.toList, drivers.values.toList,
+ finishedDrivers.values.toList, activeMasterUrl, cores, memory,
+ coresUsed, memoryUsed, activeMasterWebUiUrl))
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (master.exists(_.address == remoteAddress)) {
+ logInfo(s"$remoteAddress Disassociated !")
+ masterDisconnected()
+ }
+ }
+
private def masterDisconnected() {
logError("Connection to master failed! Waiting for master to reconnect...")
connected = false
@@ -510,13 +585,29 @@ private[worker] class Worker(
}
}
+ /**
+ * Send a message to the current master. If we have not yet registered successfully with any
+ * master, the message will be dropped.
+ */
+ private def sendToMaster(message: Any): Unit = {
+ master match {
+ case Some(masterRef) => masterRef.send(message)
+ case None =>
+ logWarning(
+ s"Dropping $message because the connection to master has not yet been established")
+ }
+ }
+
private def generateWorkerId(): String = {
"worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port)
}
- override def postStop() {
+ override def onStop() {
+ cleanupThreadExecutor.shutdownNow()
metricsSystem.report()
- registrationRetryTimer.foreach(_.cancel())
+ cancelLastRegistrationRetry()
+ forwordMessageScheduler.shutdownNow()
+ registerMasterThreadPool.shutdownNow()
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
shuffleService.stop()
@@ -530,12 +621,12 @@ private[deploy] object Worker extends Logging {
SignalLogger.register(log)
val conf = new SparkConf
val args = new WorkerArguments(argStrings, conf)
- val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
+ val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.masters, args.workDir)
- actorSystem.awaitTermination()
+ rpcEnv.awaitTermination()
}
- def startSystemAndActor(
+ def startRpcEnvAndEndpoint(
host: String,
port: Int,
webUiPort: Int,
@@ -544,18 +635,17 @@ private[deploy] object Worker extends Logging {
masterUrls: Array[String],
workDir: String,
workerNumber: Option[Int] = None,
- conf: SparkConf = new SparkConf): (ActorSystem, Int) = {
+ conf: SparkConf = new SparkConf): RpcEnv = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val actorName = "Worker"
val securityMgr = new SecurityManager(conf)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
- conf = conf, securityManager = securityMgr)
- val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))
- actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
- (actorSystem, boundPort)
+ val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr)
+ val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
+ rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, webUiPort, cores, memory, masterAddresses,
+ systemName, actorName, workDir, conf, securityMgr))
+ rpcEnv
}
def isUseLocalNodeSSLConfig(cmd: Command): Boolean = {
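
`startRpcEnvAndEndpoint` replaces `startSystemAndActor`: create an `RpcEnv`, register the endpoint, then block on termination. A hedged bootstrap sketch with placeholder names ("sparkWorkerSketch"/"WorkerSketch") and port 0 assumed to mean "pick a free port"; it assumes code inside the `org.apache.spark` package since the rpc classes are internal:

```scala
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint}

object WorkerBootstrapSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val securityMgr = new SecurityManager(conf)
    // Create the RpcEnv that owns the transport and the endpoint dispatch loop.
    val env = RpcEnv.create("sparkWorkerSketch", "localhost", 0, conf, securityMgr)
    env.setupEndpoint("WorkerSketch", new ThreadSafeRpcEndpoint {
      override val rpcEnv: RpcEnv = env
      override def receive: PartialFunction[Any, Unit] = {
        case _ => // a real endpoint would dispatch on its message protocol here
      }
    })
    env.awaitTermination()  // block the main thread until the RpcEnv shuts down
  }
}
```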
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 9678631da9f6f..e89d076802215 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -121,6 +121,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
* Print usage and exit JVM with the given exit code.
*/
def printUsageAndExit(exitCode: Int) {
+ // scalastyle:off println
System.err.println(
"Usage: Worker [options] \n" +
"\n" +
@@ -136,6 +137,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
" --webui-port PORT Port for web UI (default: 8081)\n" +
" --properties-file FILE Path to a custom Spark properties file.\n" +
" Default is conf/spark-defaults.conf.")
+ // scalastyle:on println
System.exit(exitCode)
}
@@ -160,11 +162,13 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
} catch {
case e: Exception => {
totalMb = 2*1024
+ // scalastyle:off println
System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
+ // scalastyle:on println
}
}
// Leave out 1 GB for the operating system, but don't return a negative memory size
- math.max(totalMb - 1024, 512)
+ math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB)
}
def checkWorkerMemory(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index 83fb991891a41..fae5640b9a213 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -18,7 +18,6 @@
package org.apache.spark.deploy.worker
import org.apache.spark.Logging
-import org.apache.spark.deploy.DeployMessages.SendHeartbeat
import org.apache.spark.rpc._
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
index 9f9f27d71e1ae..fd905feb97e92 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
@@ -17,10 +17,8 @@
package org.apache.spark.deploy.worker.ui
-import scala.concurrent.Await
import scala.xml.Node
-import akka.pattern.ask
import javax.servlet.http.HttpServletRequest
import org.json4s.JValue
@@ -32,18 +30,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.Utils
private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
- private val workerActor = parent.worker.self
- private val timeout = parent.timeout
+ private val workerEndpoint = parent.worker.self
override def renderJson(request: HttpServletRequest): JValue = {
- val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, timeout)
+ val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState)
JsonProtocol.writeWorkerState(workerState)
}
def render(request: HttpServletRequest): Seq[Node] = {
- val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, timeout)
+ val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState)
val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs")
val runningExecutors = workerState.executors
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index b3bb5f911dbd7..334a5b10142aa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -38,7 +38,7 @@ class WorkerWebUI(
extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI")
with Logging {
- private[ui] val timeout = RpcUtils.askTimeout(worker.conf)
+ private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf)
initialize()
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index f3a26f54a81fb..fcd76ec52742a 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -66,7 +66,10 @@ private[spark] class CoarseGrainedExecutorBackend(
case Success(msg) => Utils.tryLogNonFatalError {
Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor
}
- case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e)
+ case Failure(e) => {
+ logError(s"Cannot register with driver: $driverUrl", e)
+ System.exit(1)
+ }
}(ThreadUtils.sameThread)
}
@@ -232,7 +235,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
argv = tail
case Nil =>
case tail =>
+ // scalastyle:off println
System.err.println(s"Unrecognized options: ${tail.mkString(" ")}")
+ // scalastyle:on println
printUsageAndExit()
}
}
@@ -246,6 +251,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}
private def printUsageAndExit() = {
+ // scalastyle:off println
System.err.println(
"""
|"Usage: CoarseGrainedExecutorBackend [options]
@@ -259,6 +265,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
| --worker-url
| --user-class-path
|""".stripMargin)
+ // scalastyle:on println
System.exit(1)
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 8f916e0502ecb..f7ef92bc80f91 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -443,7 +443,7 @@ private[spark] class Executor(
try {
val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message)
if (response.reregisterBlockManager) {
- logWarning("Told to re-register on heartbeat")
+ logInfo("Told to re-register on heartbeat")
env.blockManager.reregister()
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
index c219d21fbefa9..532850dd57716 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -21,6 +21,8 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, LongWritable}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+
+import org.apache.spark.Logging
import org.apache.spark.deploy.SparkHadoopUtil
/**
@@ -39,7 +41,8 @@ private[spark] object FixedLengthBinaryInputFormat {
}
private[spark] class FixedLengthBinaryInputFormat
- extends FileInputFormat[LongWritable, BytesWritable] {
+ extends FileInputFormat[LongWritable, BytesWritable]
+ with Logging {
private var recordLength = -1
@@ -51,7 +54,7 @@ private[spark] class FixedLengthBinaryInputFormat
recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
}
if (recordLength <= 0) {
- println("record length is less than 0, file cannot be split")
+ logDebug("record length is less than 0, file cannot be split")
false
} else {
true
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
index 67a376102994c..79cb0640c8672 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
@@ -57,16 +57,6 @@ private[nio] class BlockMessage() {
}
def set(buffer: ByteBuffer) {
- /*
- println()
- println("BlockMessage: ")
- while(buffer.remaining > 0) {
- print(buffer.get())
- }
- buffer.rewind()
- println()
- println()
- */
typ = buffer.getInt()
val idLength = buffer.getInt()
val idBuilder = new StringBuilder(idLength)
@@ -138,18 +128,6 @@ private[nio] class BlockMessage() {
buffers += data
}
- /*
- println()
- println("BlockMessage: ")
- buffers.foreach(b => {
- while(b.remaining > 0) {
- print(b.get())
- }
- b.rewind()
- })
- println()
- println()
- */
Message.createBufferMessage(buffers)
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
index 7d0806f0c2580..f1c9ea8b64ca3 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
@@ -43,16 +43,6 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage])
val newBlockMessages = new ArrayBuffer[BlockMessage]()
val buffer = bufferMessage.buffers(0)
buffer.clear()
- /*
- println()
- println("BlockMessageArray: ")
- while(buffer.remaining > 0) {
- print(buffer.get())
- }
- buffer.rewind()
- println()
- println()
- */
while (buffer.remaining() > 0) {
val size = buffer.getInt()
logDebug("Creating block message of size " + size + " bytes")
@@ -86,23 +76,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage])
logDebug("Buffer list:")
buffers.foreach((x: ByteBuffer) => logDebug("" + x))
- /*
- println()
- println("BlockMessageArray: ")
- buffers.foreach(b => {
- while(b.remaining > 0) {
- print(b.get())
- }
- b.rewind()
- })
- println()
- println()
- */
Message.createBufferMessage(buffers)
}
}
-private[nio] object BlockMessageArray {
+private[nio] object BlockMessageArray extends Logging {
def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
val newBlockMessageArray = new BlockMessageArray()
@@ -123,10 +101,10 @@ private[nio] object BlockMessageArray {
}
}
val blockMessageArray = new BlockMessageArray(blockMessages)
- println("Block message array created")
+ logDebug("Block message array created")
val bufferMessage = blockMessageArray.toBufferMessage
- println("Converted to buffer message")
+ logDebug("Converted to buffer message")
val totalSize = bufferMessage.size
val newBuffer = ByteBuffer.allocate(totalSize)
@@ -138,10 +116,11 @@ private[nio] object BlockMessageArray {
})
newBuffer.flip
val newBufferMessage = Message.createBufferMessage(newBuffer)
- println("Copied to new buffer message, size = " + newBufferMessage.size)
+ logDebug("Copied to new buffer message, size = " + newBufferMessage.size)
val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
- println("Converted back to block message array")
+ logDebug("Converted back to block message array")
+ // scalastyle:off println
newBlockMessageArray.foreach(blockMessage => {
blockMessage.getType match {
case BlockMessage.TYPE_PUT_BLOCK => {
@@ -154,6 +133,7 @@ private[nio] object BlockMessageArray {
}
}
})
+ // scalastyle:on println
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index c0bca2c4bc994..9143918790381 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -1016,7 +1016,9 @@ private[spark] object ConnectionManager {
val conf = new SparkConf
val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ // scalastyle:off println
println("Received [" + msg + "] from [" + id + "]")
+ // scalastyle:on println
None
})
@@ -1033,6 +1035,7 @@ private[spark] object ConnectionManager {
System.gc()
}
+ // scalastyle:off println
def testSequentialSending(manager: ConnectionManager) {
println("--------------------------")
println("Sequential Sending")
@@ -1150,4 +1153,5 @@ private[spark] object ConnectionManager {
println()
}
}
+ // scalastyle:on println
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 33e6998b2cb10..e17bd47905d7a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.{SerializableConfiguration, Utils}
-private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
+private[spark] class CheckpointRDDPartition(val index: Int) extends Partition
/**
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
@@ -37,9 +37,11 @@ private[spark]
class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) {
- val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
+ private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
- @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
+ @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
+
+ override def getCheckpointFile: Option[String] = Some(checkpointPath)
override def getPartitions: Array[Partition] = {
val cpath = new Path(checkpointPath)
@@ -59,9 +61,6 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
}
- checkpointData = Some(new RDDCheckpointData[T](this))
- checkpointData.get.cpFile = Some(checkpointPath)
-
override def getPreferredLocations(split: Partition): Seq[String] = {
val status = fs.getFileStatus(new Path(checkpointPath,
CheckpointRDD.splitIdToFile(split.index)))
@@ -74,9 +73,9 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
CheckpointRDD.readFromFile(file, broadcastedConf, context)
}
- override def checkpoint() {
- // Do nothing. CheckpointRDD should not be checkpointed.
- }
+ // CheckpointRDD should not be checkpointed again
+ override def checkpoint(): Unit = { }
+ override def doCheckpoint(): Unit = { }
}
private[spark] object CheckpointRDD extends Logging {
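
The hunks above make `CheckpointRDD` a terminal node (`checkpoint()` and `doCheckpoint()` become no-ops) and expose its file through `getCheckpointFile`. A driver-side usage sketch with the public checkpointing API; the checkpoint directory path is a placeholder:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object CheckpointUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("checkpoint-sketch").setMaster("local[2]"))
    sc.setCheckpointDir("/tmp/spark-checkpoints")  // placeholder path

    val rdd = sc.parallelize(1 to 1000).map(_ * 2)
    rdd.cache()        // recommended: avoids recomputing the lineage when checkpointing
    rdd.checkpoint()   // only marks the RDD; data is written after the first action
    rdd.count()        // runs the job and then materializes the checkpoint

    // After checkpointing, the lineage is truncated and the file location is available.
    println(s"${rdd.isCheckpointed} ${rdd.getCheckpointFile}")
    sc.stop()
  }
}
```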
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index dc60d48927624..defdabf95ac4b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -123,7 +123,9 @@ private[spark] class PipedRDD[T: ClassTag](
new Thread("stderr reader for " + command) {
override def run() {
for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
+ // scalastyle:off println
System.err.println(line)
+ // scalastyle:on println
}
}
}.start()
@@ -133,6 +135,7 @@ private[spark] class PipedRDD[T: ClassTag](
override def run() {
val out = new PrintWriter(proc.getOutputStream)
+ // scalastyle:off println
// input the pipe context firstly
if (printPipeContext != null) {
printPipeContext(out.println(_))
@@ -144,6 +147,7 @@ private[spark] class PipedRDD[T: ClassTag](
out.println(elem)
}
}
+ // scalastyle:on println
out.close()
}
}.start()
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 10610f4b6f1ff..9f7ebae3e9af3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -194,7 +194,7 @@ abstract class RDD[T: ClassTag](
@transient private var partitions_ : Array[Partition] = null
/** An Option holding our checkpoint RDD, if we are checkpointed */
- private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
+ private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
/**
* Get the list of dependencies of this RDD, taking into account whether the
@@ -890,6 +890,10 @@ abstract class RDD[T: ClassTag](
* Return an iterator that contains all of the elements in this RDD.
*
* The iterator will consume as much memory as the largest partition in this RDD.
+ *
+ * Note: this results in multiple Spark jobs, and if the input RDD is the result
+ * of a wide transformation (e.g. join with different partitioners), the input RDD
+ * should be cached first to avoid recomputing it.
*/
def toLocalIterator: Iterator[T] = withScope {
def collectPartition(p: Int): Array[T] = {
@@ -1447,12 +1451,16 @@ abstract class RDD[T: ClassTag](
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
- def checkpoint() {
+ def checkpoint(): Unit = {
if (context.checkpointDir.isEmpty) {
throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
- checkpointData = Some(new RDDCheckpointData(this))
- checkpointData.get.markForCheckpoint()
+ // NOTE: we use a global lock here due to complexities downstream with ensuring
+ // children RDD partitions point to the correct parent partitions. In the future
+ // we should revisit this consideration.
+ RDDCheckpointData.synchronized {
+ checkpointData = Some(new RDDCheckpointData(this))
+ }
}
}
@@ -1493,7 +1501,7 @@ abstract class RDD[T: ClassTag](
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
/** Returns the first parent RDD */
- protected[spark] def firstParent[U: ClassTag] = {
+ protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
}
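
The new `toLocalIterator` note warns that each partition is fetched by its own job. A short usage sketch of that advice: cache a shuffled RDD before iterating it locally, so the shuffle is not recomputed once per partition.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object LocalIteratorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("iterator-sketch").setMaster("local[2]"))
    val wide = sc.parallelize(1 to 100000).map(i => (i % 100, i)).reduceByKey(_ + _)

    wide.cache()  // without this, every partition fetch would recompute the shuffle
    wide.toLocalIterator.take(5).foreach(println)  // one job per partition, bounded driver memory

    sc.stop()
  }
}
```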
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index acbd31aacdf59..4f954363bed8e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -22,16 +22,15 @@ import scala.reflect.ClassTag
import org.apache.hadoop.fs.Path
import org.apache.spark._
-import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}
import org.apache.spark.util.SerializableConfiguration
/**
* Enumeration to manage state transitions of an RDD through checkpointing
- * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
+ * [ Initialized --> checkpointing in progress --> checkpointed ].
*/
private[spark] object CheckpointState extends Enumeration {
type CheckpointState = Value
- val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
+ val Initialized, CheckpointingInProgress, Checkpointed = Value
}
/**
@@ -46,37 +45,37 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
import CheckpointState._
// The checkpoint state of the associated RDD.
- var cpState = Initialized
+ private var cpState = Initialized
// The file to which the associated RDD has been checkpointed to
- @transient var cpFile: Option[String] = None
+ private var cpFile: Option[String] = None
// The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
- var cpRDD: Option[RDD[T]] = None
+ // This is defined if and only if `cpState` is `Checkpointed`.
+ private var cpRDD: Option[CheckpointRDD[T]] = None
- // Mark the RDD for checkpointing
- def markForCheckpoint() {
- RDDCheckpointData.synchronized {
- if (cpState == Initialized) cpState = MarkedForCheckpoint
- }
- }
+ // TODO: are we sure we need to use a global lock in the following methods?
// Is the RDD already checkpointed
- def isCheckpointed: Boolean = {
- RDDCheckpointData.synchronized { cpState == Checkpointed }
+ def isCheckpointed: Boolean = RDDCheckpointData.synchronized {
+ cpState == Checkpointed
}
// Get the file to which this RDD was checkpointed to as an Option
- def getCheckpointFile: Option[String] = {
- RDDCheckpointData.synchronized { cpFile }
+ def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized {
+ cpFile
}
- // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
- def doCheckpoint() {
- // If it is marked for checkpointing AND checkpointing is not already in progress,
- // then set it to be in progress, else return
+ /**
+ * Materialize this RDD and write its content to a reliable DFS.
+ * This is called immediately after the first action invoked on this RDD has completed.
+ */
+ def doCheckpoint(): Unit = {
+
+ // Guard against multiple threads checkpointing the same RDD by
+ // atomically flipping the state of this RDDCheckpointData
RDDCheckpointData.synchronized {
- if (cpState == MarkedForCheckpoint) {
+ if (cpState == Initialized) {
cpState = CheckpointingInProgress
} else {
return
@@ -87,7 +86,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
- throw new SparkException("Failed to create checkpoint path " + path)
+ throw new SparkException(s"Failed to create checkpoint path $path")
}
// Save to file, and reload it as an RDD
@@ -99,6 +98,8 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
}
}
+
+ // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
if (newRDD.partitions.length != rdd.partitions.length) {
throw new SparkException(
@@ -113,34 +114,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
}
- logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
- }
-
- // Get preferred location of a split after checkpointing
- def getPreferredLocations(split: Partition): Seq[String] = {
- RDDCheckpointData.synchronized {
- cpRDD.get.preferredLocations(split)
- }
+ logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}")
}
- def getPartitions: Array[Partition] = {
- RDDCheckpointData.synchronized {
- cpRDD.get.partitions
- }
+ def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
+ cpRDD.get.partitions
}
- def checkpointRDD: Option[RDD[T]] = {
- RDDCheckpointData.synchronized {
- cpRDD
- }
+ def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized {
+ cpRDD
}
}
private[spark] object RDDCheckpointData {
+
+ /** Return the path of the directory to which this RDD's checkpoint data is written. */
def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
- sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) }
+ sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
}
+ /** Clean up the files associated with the checkpoint data for this RDD. */
def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = {
rddCheckpointDataPath(sc, rddId).foreach { path =>
val fs = path.getFileSystem(sc.hadoopConfiguration)
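
The rewritten `doCheckpoint` guards against concurrent checkpointing by atomically flipping the state from `Initialized` to `CheckpointingInProgress` and bailing out otherwise. A simplified, hedged sketch of that "flip the state or return" guard, not the actual `RDDCheckpointData` implementation:

```scala
object CheckpointStateSketch {
  object State extends Enumeration { val Initialized, InProgress, Done = Value }
  import State._

  private var state = Initialized
  private val lock = new Object

  // Run `work` at most once, even if several threads call runOnce concurrently.
  def runOnce(work: () => Unit): Unit = {
    val shouldRun = lock.synchronized {
      if (state == Initialized) { state = InProgress; true } else false
    }
    if (shouldRun) {
      work()                               // the expensive write happens outside the lock
      lock.synchronized { state = Done }
    }
  }
}
```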
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
index 69181edb9ad44..6ae47894598be 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -17,8 +17,7 @@
package org.apache.spark.rpc
-import scala.concurrent.{Await, Future}
-import scala.concurrent.duration.FiniteDuration
+import scala.concurrent.Future
import scala.reflect.ClassTag
import org.apache.spark.util.RpcUtils
@@ -32,7 +31,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
private[this] val maxRetries = RpcUtils.numRetries(conf)
private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf)
- private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf)
+ private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf)
/**
* return the address for the [[RpcEndpointRef]]
@@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
*
* This method only sends the message once and never retries.
*/
- def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T]
+ def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to
@@ -91,7 +90,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
- def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = {
+ def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
// TODO: Consider removing multiple attempts
var attempts = 0
var lastException: Exception = null
@@ -99,7 +98,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
attempts += 1
try {
val future = ask[T](message, timeout)
- val result = Await.result(future, timeout)
+ val result = timeout.awaitResult(future)
if (result == null) {
throw new SparkException("Actor returned null")
}
@@ -110,10 +109,14 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
lastException = e
logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
}
- Thread.sleep(retryWaitMs)
+
+ if (attempts < maxRetries) {
+ Thread.sleep(retryWaitMs)
+ }
}
throw new SparkException(
s"Error sending message [message = $message]", lastException)
}
+
}
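
The `askWithRetry` change above avoids a pointless sleep after the final failed attempt. A generic sketch of that retry-loop shape, written without the Spark types:

```scala
object RetrySketch {
  // Retry a blocking call up to maxRetries times, sleeping only when another attempt follows.
  def retry[T](maxRetries: Int, waitMs: Long)(call: () => T): T = {
    var attempts = 0
    var lastException: Exception = null
    while (attempts < maxRetries) {
      attempts += 1
      try {
        return call()
      } catch {
        case e: Exception => lastException = e
      }
      if (attempts < maxRetries) {
        Thread.sleep(waitMs)  // no sleep after the final failure
      }
    }
    throw new RuntimeException(s"Failed after $attempts attempts", lastException)
  }
}
```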
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 12b6b28d4d7ec..1709bdf560b6f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -18,8 +18,10 @@
package org.apache.spark.rpc
import java.net.URI
+import java.util.concurrent.TimeoutException
-import scala.concurrent.{Await, Future}
+import scala.concurrent.{Awaitable, Await, Future}
+import scala.concurrent.duration._
import scala.language.postfixOps
import org.apache.spark.{SecurityManager, SparkConf}
@@ -66,7 +68,7 @@ private[spark] object RpcEnv {
*/
private[spark] abstract class RpcEnv(conf: SparkConf) {
- private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf)
+ private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf)
/**
* Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement
@@ -94,7 +96,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
* Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action.
*/
def setupEndpointRefByURI(uri: String): RpcEndpointRef = {
- Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout)
+ defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri))
}
/**
@@ -158,6 +160,8 @@ private[spark] case class RpcAddress(host: String, port: Int) {
val hostPort: String = host + ":" + port
override val toString: String = hostPort
+
+ def toSparkURL: String = "spark://" + hostPort
}
@@ -182,3 +186,107 @@ private[spark] object RpcAddress {
RpcAddress(host, port)
}
}
+
+
+/**
+ * An exception thrown when an [[RpcTimeout]] adds its context to an intercepted [[TimeoutException]].
+ */
+private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
+ extends TimeoutException(message) { initCause(cause) }
+
+
+/**
+ * Associates a timeout with a description so that when a TimeoutException occurs, additional
+ * context about the timeout can be appended to the exception message.
+ * @param duration timeout duration in seconds
+ * @param timeoutProp the configuration property that controls this timeout
+ */
+private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String)
+ extends Serializable {
+
+ /** Amends the standard message of TimeoutException to include the description */
+ private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
+ new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te)
+ }
+
+ /**
+ * PartialFunction to match a TimeoutException and add the timeout description to the message
+ *
+ * @note This can be used in the recover callback of a Future to add context to a TimeoutException
+ * Example:
+ * val timeout = new RpcTimeout(5 millis, "short timeout")
+ * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
+ */
+ def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
+ // The exception has already been converted to a RpcTimeoutException so just raise it
+ case rte: RpcTimeoutException => throw rte
+ // Any other TimeoutException is converted to an RpcTimeoutException with a modified message
+ case te: TimeoutException => throw createRpcTimeoutException(te)
+ }
+
+ /**
+ * Wait for the completed result and return it. If the result is not available within this
+ * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout.
+ * @param awaitable the `Awaitable` to be awaited
+ * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
+ * is still not ready
+ */
+ def awaitResult[T](awaitable: Awaitable[T]): T = {
+ try {
+ Await.result(awaitable, duration)
+ } catch addMessageIfTimeout
+ }
+}
+
+
+private[spark] object RpcTimeout {
+
+ /**
+ * Lookup the timeout property in the configuration and create
+ * a RpcTimeout with the property key in the description.
+ * @param conf configuration properties containing the timeout
+ * @param timeoutProp property key for the timeout in seconds
+ * @throws NoSuchElementException if property is not set
+ */
+ def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = {
+ val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds }
+ new RpcTimeout(timeout, timeoutProp)
+ }
+
+ /**
+ * Lookup the timeout property in the configuration and create
+ * a RpcTimeout with the property key in the description.
+ * Uses the given default value if property is not set
+ * @param conf configuration properties containing the timeout
+ * @param timeoutProp property key for the timeout in seconds
+ * @param defaultValue default timeout value in seconds if property not found
+ */
+ def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = {
+ val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds }
+ new RpcTimeout(timeout, timeoutProp)
+ }
+
+ /**
+ * Lookup prioritized list of timeout properties in the configuration
+ * and create a RpcTimeout with the first set property key in the
+ * description.
+ * Uses the given default value if property is not set
+ * @param conf configuration properties containing the timeout
+ * @param timeoutPropList prioritized list of property keys for the timeout in seconds
+ * @param defaultValue default timeout value in seconds if no properties found
+ */
+ def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = {
+ require(timeoutPropList.nonEmpty)
+
+ // Find the first set property or use the default value with the first property
+ val itr = timeoutPropList.iterator
+ var foundProp: Option[(String, String)] = None
+ while (itr.hasNext && foundProp.isEmpty) {
+ val propKey = itr.next()
+ conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
+ }
+ val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
+ val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds }
+ new RpcTimeout(timeout, finalProp._1)
+ }
+}
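
A usage sketch of the `RpcTimeout` API introduced above, assuming code inside the `org.apache.spark` package (the class is Spark-internal); the property keys are the usual spark.rpc.* names and the "2s" default is illustrative:

```scala
import scala.concurrent.{Future, Promise}
import org.apache.spark.SparkConf
import org.apache.spark.rpc.RpcTimeout

object RpcTimeoutUsageSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // Walks the property list, remembers which key controls the timeout, and falls back to
    // the default when none is set, so a timeout exception can name the controlling property.
    val askTimeout =
      RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "2s")

    val never: Future[String] = Promise[String]().future  // a future that never completes
    try {
      askTimeout.awaitResult(never)
    } catch {
      // The message ends with "... This timeout is controlled by spark.rpc.askTimeout".
      case e: java.util.concurrent.TimeoutException => println(e.getMessage)
    }
  }
}
```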
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 0161962cde073..f2d87f68341af 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka
import java.util.concurrent.ConcurrentHashMap
import scala.concurrent.Future
-import scala.concurrent.duration._
import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal
@@ -180,10 +179,10 @@ private[spark] class AkkaRpcEnv private[akka] (
})
} catch {
case NonFatal(e) =>
- if (needReply) {
- // If the sender asks a reply, we should send the error back to the sender
- _sender ! AkkaFailure(e)
- } else {
+ _sender ! AkkaFailure(e)
+ if (!needReply) {
+ // If the sender does not require a reply, it may not handle the exception. So we rethrow
+ // "e" to make sure it will be processed.
throw e
}
}
@@ -214,8 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] (
override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
import actorSystem.dispatcher
- actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout).
- map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
+ actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
+ map(new AkkaRpcEndpointRef(defaultAddress, _, conf)).
+ // This is just in case there is a timeout when creating the future in resolveOne; we want the
+ // exception to indicate the conf property that determines the timeout
+ recover(defaultLookupTimeout.addMessageIfTimeout)
}
override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
@@ -295,8 +297,8 @@ private[akka] class AkkaRpcEndpointRef(
actorRef ! AkkaMessage(message, false)
}
- override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
- actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
+ override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
+ actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
// The function will run in the calling thread, so it should be short and never block.
case msg @ AkkaMessage(message, reply) =>
if (reply) {
@@ -307,7 +309,8 @@ private[akka] class AkkaRpcEndpointRef(
}
case AkkaFailure(e) =>
Future.failed(e)
- }(ThreadUtils.sameThread).mapTo[T]
+ }(ThreadUtils.sameThread).mapTo[T].
+ recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
}
override def toString: String = s"${getClass.getSimpleName}($actorRef)"
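The `recover(timeout.addMessageIfTimeout)` calls above rely on a partial function defined on `RpcTimeout` earlier in this patch. Below is a self-contained sketch of that pattern; the names are hypothetical and a plain `TimeoutException` stands in for Spark's own exception type.

```scala
import java.util.concurrent.TimeoutException
import scala.concurrent.{ExecutionContext, Future}

// Hedged sketch of the "tag the timeout with its config key" pattern.
def addMessageIfTimeout[T](timeoutProp: String): PartialFunction[Throwable, T] = {
  // Only rewrap timeouts; any other failure passes through untouched.
  case te: TimeoutException =>
    throw new TimeoutException(s"${te.getMessage}. This timeout is controlled by $timeoutProp")
}

def askWithTaggedTimeout[T](future: Future[T], timeoutProp: String)
    (implicit ec: ExecutionContext): Future[T] = {
  // Future.recover only fires for failed futures, so successful replies are unchanged.
  future.recover(addMessageIfTimeout[T](timeoutProp))
}
```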
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c6029675eab0e..11b12edf7eaf1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.storage._
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
@@ -188,7 +189,7 @@ class DAGScheduler(
blockManagerId: BlockManagerId): Boolean = {
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
blockManagerMaster.driverEndpoint.askWithRetry[Boolean](
- BlockManagerHeartbeat(blockManagerId), 600 seconds)
+ BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
}
// Called by TaskScheduler when an executor fails.
@@ -870,7 +871,7 @@ class DAGScheduler(
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
- stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
+ stage.makeNewStageAttempt(partitionsToCompute.size)
outputCommitCoordinator.stageStart(stage.id)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
@@ -935,8 +936,8 @@ class DAGScheduler(
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
stage.pendingTasks ++= tasks
logDebug("New pending tasks: " + stage.pendingTasks)
- taskScheduler.submitTasks(
- new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.firstJobId, properties))
+ taskScheduler.submitTasks(new TaskSet(
+ tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties))
stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should mark
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 529a5b2bf1a0d..62b05033a9281 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -140,7 +140,9 @@ private[spark] class EventLoggingListener(
/** Log the event as JSON. */
private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) {
val eventJson = JsonProtocol.sparkEventToJson(event)
+ // scalastyle:off println
writer.foreach(_.println(compact(render(eventJson))))
+ // scalastyle:on println
if (flushLogger) {
writer.foreach(_.flush())
hadoopDataStream.foreach(hadoopFlushMethod.invoke(_))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index e55b76c36cc5f..f96eb8ca0ae00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -125,7 +125,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
val date = new Date(System.currentTimeMillis())
writeInfo = dateFormat.get.format(date) + ": " + info
}
+ // scalastyle:off println
jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo))
+ // scalastyle:on println
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 14ab2b86e1b77..b86724de2cb73 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -62,28 +62,28 @@ private[spark] abstract class Stage(
var pendingTasks = new HashSet[Task[_]]
+ /** The ID to use for the next new attempt for this stage. */
private var nextAttemptId: Int = 0
val name = callSite.shortForm
val details = callSite.longForm
- /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */
- var latestInfo: StageInfo = StageInfo.fromStage(this)
+ /**
+ * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized
+ * here, before any attempts have actually been created, because the DAGScheduler uses this
+ * StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts
+ * have been created).
+ */
+ private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId)
- /** Return a new attempt id, starting with 0. */
- def newAttemptId(): Int = {
- val id = nextAttemptId
+ /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
+ def makeNewStageAttempt(numPartitionsToCompute: Int): Unit = {
+ _latestInfo = StageInfo.fromStage(this, nextAttemptId, Some(numPartitionsToCompute))
nextAttemptId += 1
- id
}
- /**
- * The id for the **next** stage attempt.
- *
- * The unusual meaning of this method means its unlikely to hold the value you are interested in
- * -- you probably want to use [[latestInfo.attemptId]]
- */
- private[spark] def attemptId: Int = nextAttemptId
+ /** Returns the StageInfo for the most recent attempt for this stage. */
+ def latestInfo: StageInfo = _latestInfo
override final def hashCode(): Int = id
override final def equals(other: Any): Boolean = other match {
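A tiny standalone model (not Spark code) of the attempt bookkeeping introduced above: `latestInfo` is populated eagerly in the constructor with the next id, and each call to `makeNewStageAttempt` reuses the current id before bumping it, so the first real attempt also carries id 0.

```scala
// Toy model of the Stage attempt-id bookkeeping; names are deliberately different.
class ToyStage {
  private var nextAttemptId = 0
  private var _latestAttemptId = nextAttemptId // placeholder before any real attempt
  def makeNewStageAttempt(): Unit = {
    _latestAttemptId = nextAttemptId
    nextAttemptId += 1
  }
  def latestAttemptId: Int = _latestAttemptId
}

val s = new ToyStage
assert(s.latestAttemptId == 0)                           // placeholder info
s.makeNewStageAttempt(); assert(s.latestAttemptId == 0)  // first attempt reuses id 0
s.makeNewStageAttempt(); assert(s.latestAttemptId == 1)  // resubmission bumps it
```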
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index e439d2a7e1229..5d2abbc67e9d9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -70,12 +70,12 @@ private[spark] object StageInfo {
* shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a
* sequence of narrow dependencies should also be associated with this Stage.
*/
- def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = {
+ def fromStage(stage: Stage, attemptId: Int, numTasks: Option[Int] = None): StageInfo = {
val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos
new StageInfo(
stage.id,
- stage.attemptId,
+ attemptId,
stage.name,
numTasks.getOrElse(stage.numTasks),
rddInfos,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index ccf1dc5af6120..687ae9620460f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -85,7 +85,7 @@ private[spark] class SparkDeploySchedulerBackend(
val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt)
val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory,
command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor)
- client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf)
+ client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf)
client.start()
waitForRegistration()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 190ff61d689d1..bc67abb5df446 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -46,7 +46,7 @@ private[spark] abstract class YarnSchedulerBackend(
private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint(
YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv))
- private implicit val askTimeout = RpcUtils.askTimeout(sc.conf)
+ private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf)
/**
* Request executors from the ApplicationMaster by specifying the total number desired.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 6b8edca5aa485..cbade131494bc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -18,18 +18,21 @@
package org.apache.spark.scheduler.cluster.mesos
import java.io.File
-import java.util.{Collections, List => JList}
+import java.util.{List => JList, Collections}
+import java.util.concurrent.locks.ReentrantLock
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, HashSet}
+import com.google.common.collect.HashBiMap
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
import org.apache.mesos.{Scheduler => MScheduler, _}
+import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
+import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState}
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.util.Utils
-import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -60,12 +63,34 @@ private[spark] class CoarseMesosSchedulerBackend(
val slaveIdsWithExecutors = new HashSet[String]
- val taskIdToSlaveId = new HashMap[Int, String]
- val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
+ val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String]
+ // How many times tasks on each slave failed
+ val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int]
+
+ /**
+ * The total number of executors we aim to have. Undefined when not using dynamic allocation
+ * and before the ExecutorAllocationManager calls [[doRequestTotalExecutors]].
+ */
+ private var executorLimitOption: Option[Int] = None
+
+ /**
+ * Return the current executor limit, which may be [[Int.MaxValue]]
+ * before properly initialized.
+ */
+ private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue)
+
+ private val pendingRemovedSlaveIds = new HashSet[String]
+ // Private lock object protecting the mutable state above. Using the intrinsic lock
+ // may lead to deadlocks since the superclass might also try to lock.
+ private val stateLock = new ReentrantLock
val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0)
+ // Offer constraints
+ private val slaveOfferConstraints =
+ parseConstraintString(sc.conf.get("spark.mesos.constraints", ""))
+
var nextMesosTaskId = 0
@volatile var appId: String = _
@@ -82,7 +107,7 @@ private[spark] class CoarseMesosSchedulerBackend(
startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo)
}
- def createCommand(offer: Offer, numCores: Int): CommandInfo = {
+ def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = {
val executorSparkHome = conf.getOption("spark.mesos.executor.home")
.orElse(sc.getSparkHome())
.getOrElse {
@@ -116,10 +141,6 @@ private[spark] class CoarseMesosSchedulerBackend(
}
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
- val driverUrl = sc.env.rpcEnv.uriOf(
- SparkEnv.driverActorSystemName,
- RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt),
- CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
val uri = conf.getOption("spark.executor.uri")
.orElse(Option(System.getenv("SPARK_EXECUTOR_URI")))
@@ -129,7 +150,7 @@ private[spark] class CoarseMesosSchedulerBackend(
command.setValue(
"%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend"
.format(prefixEnv, runScript) +
- s" --driver-url $driverUrl" +
+ s" --driver-url $driverURL" +
s" --executor-id ${offer.getSlaveId.getValue}" +
s" --hostname ${offer.getHostname}" +
s" --cores $numCores" +
@@ -138,11 +159,12 @@ private[spark] class CoarseMesosSchedulerBackend(
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.get.split('/').last.split('.').head
+ val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString)
command.setValue(
s"cd $basename*; $prefixEnv " +
"./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" +
- s" --driver-url $driverUrl" +
- s" --executor-id ${offer.getSlaveId.getValue}" +
+ s" --driver-url $driverURL" +
+ s" --executor-id $executorId" +
s" --hostname ${offer.getHostname}" +
s" --cores $numCores" +
s" --app-id $appId")
@@ -151,6 +173,17 @@ private[spark] class CoarseMesosSchedulerBackend(
command.build()
}
+ protected def driverURL: String = {
+ if (conf.contains("spark.testing")) {
+ "driverURL"
+ } else {
+ sc.env.rpcEnv.uriOf(
+ SparkEnv.driverActorSystemName,
+ RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt),
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
+ }
+ }
+
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
@@ -168,15 +201,19 @@ private[spark] class CoarseMesosSchedulerBackend(
* unless we've already launched more than we wanted to.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
- synchronized {
+ stateLock.synchronized {
val filters = Filters.newBuilder().setRefuseSeconds(5).build()
-
for (offer <- offers) {
- val slaveId = offer.getSlaveId.toString
+ val offerAttributes = toAttributeMap(offer.getAttributesList)
+ val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes)
+ val slaveId = offer.getSlaveId.getValue
val mem = getResource(offer.getResourcesList, "mem")
val cpus = getResource(offer.getResourcesList, "cpus").toInt
- if (totalCoresAcquired < maxCores &&
- mem >= MemoryUtils.calculateTotalMemory(sc) &&
+ val id = offer.getId.getValue
+ if (taskIdToSlaveId.size < executorLimit &&
+ totalCoresAcquired < maxCores &&
+ meetsConstraints &&
+ mem >= calculateTotalMemory(sc) &&
cpus >= 1 &&
failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES &&
!slaveIdsWithExecutors.contains(slaveId)) {
@@ -190,42 +227,36 @@ private[spark] class CoarseMesosSchedulerBackend(
val task = MesosTaskInfo.newBuilder()
.setTaskId(TaskID.newBuilder().setValue(taskId.toString).build())
.setSlaveId(offer.getSlaveId)
- .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave))
+ .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId))
.setName("Task " + taskId)
.addResources(createResource("cpus", cpusToUse))
- .addResources(createResource("mem",
- MemoryUtils.calculateTotalMemory(sc)))
+ .addResources(createResource("mem", calculateTotalMemory(sc)))
sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image =>
MesosSchedulerBackendUtil
- .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder())
+ .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder)
}
+ // accept the offer and launch the task
+ logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
d.launchTasks(
- Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters)
+ Collections.singleton(offer.getId),
+ Collections.singleton(task.build()), filters)
} else {
- // Filter it out
- d.launchTasks(
- Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters)
+ // Decline the offer
+ logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
+ d.declineOffer(offer.getId)
}
}
}
}
- /** Build a Mesos resource protobuf object */
- private def createResource(resourceName: String, quantity: Double): Protos.Resource = {
- Resource.newBuilder()
- .setName(resourceName)
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(quantity).build())
- .build()
- }
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
val taskId = status.getTaskId.getValue.toInt
val state = status.getState
logInfo("Mesos task " + taskId + " is now " + state)
- synchronized {
+ stateLock.synchronized {
if (TaskState.isFinished(TaskState.fromMesos(state))) {
val slaveId = taskIdToSlaveId(taskId)
slaveIdsWithExecutors -= slaveId
@@ -243,8 +274,9 @@ private[spark] class CoarseMesosSchedulerBackend(
"is Spark installed on it?")
}
}
+ executorTerminated(d, slaveId, s"Executor finished with state $state")
// In case we'd rejected everything before but have now lost a node
- mesosDriver.reviveOffers()
+ d.reviveOffers()
}
}
}
@@ -263,18 +295,39 @@ private[spark] class CoarseMesosSchedulerBackend(
override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
- override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
- logInfo("Mesos slave lost: " + slaveId.getValue)
- synchronized {
- if (slaveIdsWithExecutors.contains(slaveId.getValue)) {
- // Note that the slave ID corresponds to the executor ID on that slave
- slaveIdsWithExecutors -= slaveId.getValue
- removeExecutor(slaveId.getValue, "Mesos slave lost")
+ /**
+ * Called when a slave is lost or a Mesos task has finished. Updates the local view of
+ * which tasks are running and removes the terminated slave from the list of pending
+ * slave IDs that we might have asked to be killed. It also notifies the driver
+ * that an executor was removed.
+ */
+ private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = {
+ stateLock.synchronized {
+ if (slaveIdsWithExecutors.contains(slaveId)) {
+ val slaveIdToTaskId = taskIdToSlaveId.inverse()
+ if (slaveIdToTaskId.contains(slaveId)) {
+ val taskId: Int = slaveIdToTaskId.get(slaveId)
+ taskIdToSlaveId.remove(taskId)
+ removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason)
+ }
+ // TODO: This assumes one Spark executor per Mesos slave,
+ // which may no longer be true after SPARK-5095
+ pendingRemovedSlaveIds -= slaveId
+ slaveIdsWithExecutors -= slaveId
}
}
}
- override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) {
+ private def sparkExecutorId(slaveId: String, taskId: String): String = {
+ s"$slaveId/$taskId"
+ }
+
+ override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = {
+ logInfo("Mesos slave lost: " + slaveId.getValue)
+ executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue)
+ }
+
+ override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = {
logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
slaveLost(d, s)
}
@@ -285,4 +338,34 @@ private[spark] class CoarseMesosSchedulerBackend(
super.applicationId
}
+ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
+ // We don't truly know whether we can fulfill the requested number of executors,
+ // since with coarse-grained scheduling it depends on the number of slaves available.
+ logInfo("Capping the total amount of executors to " + requestedTotal)
+ executorLimitOption = Some(requestedTotal)
+ true
+ }
+
+ override def doKillExecutors(executorIds: Seq[String]): Boolean = {
+ if (mesosDriver == null) {
+ logWarning("Asked to kill executors before the Mesos driver was started.")
+ return false
+ }
+
+ val slaveIdToTaskId = taskIdToSlaveId.inverse()
+ for (executorId <- executorIds) {
+ val slaveId = executorId.split("/")(0)
+ if (slaveIdToTaskId.contains(slaveId)) {
+ mesosDriver.killTask(
+ TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build())
+ pendingRemovedSlaveIds += slaveId
+ } else {
+ logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler")
+ }
+ }
+ // no need to adjust `executorLimitOption` since the AllocationManager already communicated
+ // the desired limit through a call to `doRequestTotalExecutors`.
+ // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]]
+ true
+ }
}
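To make the executor-id scheme above concrete, here is a small self-contained illustration (not part of the patch) of why a Guava `HashBiMap` is used: the Spark executor id is `slaveId/taskId`, and `doKillExecutors` has to map the slave id back to the Mesos task id through the inverse view. The ids are made up.

```scala
import com.google.common.collect.HashBiMap

// Toy model of the taskId <-> slaveId bookkeeping in the backend above.
val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String]()
taskIdToSlaveId.put(7, "slave-42")

// Executor ids are "slaveId/taskId", as in sparkExecutorId(...) above.
def sparkExecutorId(slaveId: String, taskId: String): String = s"$slaveId/$taskId"
val executorId = sparkExecutorId("slave-42", "7") // "slave-42/7"

// Killing that executor: recover the Mesos task id from the slave id.
val slaveIdToTaskId = taskIdToSlaveId.inverse()   // BiMap[String, Int]
val slaveId = executorId.split("/")(0)            // "slave-42"
if (slaveIdToTaskId.containsKey(slaveId)) {
  val mesosTaskId = slaveIdToTaskId.get(slaveId)  // 7 -> the TaskID to kill
  println(s"would kill Mesos task $mesosTaskId")
}
```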
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index 1067a7f1caf4c..d3a20f822176e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -29,6 +29,7 @@ import org.apache.mesos.Protos.Environment.Variable
import org.apache.mesos.Protos.TaskStatus.Reason
import org.apache.mesos.Protos.{TaskState => MesosTaskState, _}
import org.apache.mesos.{Scheduler, SchedulerDriver}
+
import org.apache.spark.deploy.mesos.MesosDriverDescription
import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse}
import org.apache.spark.metrics.MetricsSystem
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 49de85ef48ada..d72e2af456e15 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -23,14 +23,14 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList}
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, HashSet}
+import org.apache.mesos.{Scheduler => MScheduler, _}
import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _}
import org.apache.mesos.protobuf.ByteString
-import org.apache.mesos.{Scheduler => MScheduler, _}
+import org.apache.spark.{SparkContext, SparkException, TaskState}
import org.apache.spark.executor.MesosExecutorBackend
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.util.Utils
-import org.apache.spark.{SparkContext, SparkException, TaskState}
/**
* A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a
@@ -59,6 +59,10 @@ private[spark] class MesosSchedulerBackend(
private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1)
+ // Offer constraints
+ private[this] val slaveOfferConstraints =
+ parseConstraintString(sc.conf.get("spark.mesos.constraints", ""))
+
@volatile var appId: String = _
override def start() {
@@ -71,8 +75,8 @@ private[spark] class MesosSchedulerBackend(
val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home")
.orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility
.getOrElse {
- throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
- }
+ throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
+ }
val environment = Environment.newBuilder()
sc.conf.getOption("spark.executor.extraClassPath").foreach { cp =>
environment.addVariables(
@@ -115,14 +119,14 @@ private[spark] class MesosSchedulerBackend(
.setName("cpus")
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder()
- .setValue(mesosExecutorCores).build())
+ .setValue(mesosExecutorCores).build())
.build()
val memory = Resource.newBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
.setScalar(
Value.Scalar.newBuilder()
- .setValue(MemoryUtils.calculateTotalMemory(sc)).build())
+ .setValue(calculateTotalMemory(sc)).build())
.build()
val executorInfo = MesosExecutorInfo.newBuilder()
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
@@ -191,13 +195,31 @@ private[spark] class MesosSchedulerBackend(
val mem = getResource(o.getResourcesList, "mem")
val cpus = getResource(o.getResourcesList, "cpus")
val slaveId = o.getSlaveId.getValue
- (mem >= MemoryUtils.calculateTotalMemory(sc) &&
- // need at least 1 for executor, 1 for task
- cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) ||
- (slaveIdsWithExecutors.contains(slaveId) &&
- cpus >= scheduler.CPUS_PER_TASK)
+ val offerAttributes = toAttributeMap(o.getAttributesList)
+
+ // check if all constraints are satisfied
+ // 1. Attribute constraints
+ // 2. Memory requirements
+ // 3. CPU requirements - need at least 1 for executor, 1 for task
+ val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes)
+ val meetsMemoryRequirements = mem >= calculateTotalMemory(sc)
+ val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)
+
+ val meetsRequirements =
+ (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) ||
+ (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK)
+
+ // add some debug messaging
+ val debugstr = if (meetsRequirements) "Accepting" else "Declining"
+ val id = o.getId.getValue
+ logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
+
+ meetsRequirements
}
+ // Decline offers we ruled out immediately
+ unUsableOffers.foreach(o => d.declineOffer(o.getId))
+
val workerOffers = usableOffers.map { o =>
val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) {
getResource(o.getResourcesList, "cpus").toInt
@@ -223,15 +245,15 @@ private[spark] class MesosSchedulerBackend(
val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty)
acceptedOffers
.foreach { offer =>
- offer.foreach { taskDesc =>
- val slaveId = taskDesc.executorId
- slaveIdsWithExecutors += slaveId
- slavesIdsOfAcceptedOffers += slaveId
- taskIdToSlaveId(taskDesc.taskId) = slaveId
- mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
- .add(createMesosTask(taskDesc, slaveId))
- }
+ offer.foreach { taskDesc =>
+ val slaveId = taskDesc.executorId
+ slaveIdsWithExecutors += slaveId
+ slavesIdsOfAcceptedOffers += slaveId
+ taskIdToSlaveId(taskDesc.taskId) = slaveId
+ mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
+ .add(createMesosTask(taskDesc, slaveId))
}
+ }
// Reply to the offers
val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
@@ -251,8 +273,6 @@ private[spark] class MesosSchedulerBackend(
d.declineOffer(o.getId)
}
- // Decline offers we ruled out immediately
- unUsableOffers.foreach(o => d.declineOffer(o.getId))
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index d11228f3d016a..925702e63afd3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -17,14 +17,17 @@
package org.apache.spark.scheduler.cluster.mesos
-import java.util.List
+import java.util.{List => JList}
import java.util.concurrent.CountDownLatch
import scala.collection.JavaConversions._
+import scala.util.control.NonFatal
-import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status}
-import org.apache.mesos.{MesosSchedulerDriver, Scheduler}
-import org.apache.spark.Logging
+import com.google.common.base.Splitter
+import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos}
+import org.apache.mesos.Protos._
+import org.apache.mesos.protobuf.GeneratedMessage
+import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.util.Utils
/**
@@ -36,7 +39,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
private final val registerLatch = new CountDownLatch(1)
// Driver for talking to Mesos
- protected var mesosDriver: MesosSchedulerDriver = null
+ protected var mesosDriver: SchedulerDriver = null
/**
* Starts the MesosSchedulerDriver with the provided information. This method returns
@@ -86,10 +89,150 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
/**
* Get the amount of resources for the specified type from the resource list
*/
- protected def getResource(res: List[Resource], name: String): Double = {
+ protected def getResource(res: JList[Resource], name: String): Double = {
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
0.0
}
+
+ /** Helper method to get the (name, value-set) pair for a Mesos Attribute protobuf */
+ protected def getAttribute(attr: Attribute): (String, Set[String]) = {
+ (attr.getName, attr.getText.getValue.split(',').toSet)
+ }
+
+
+ /** Build a Mesos resource protobuf object */
+ protected def createResource(resourceName: String, quantity: Double): Protos.Resource = {
+ Resource.newBuilder()
+ .setName(resourceName)
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder().setValue(quantity).build())
+ .build()
+ }
+
+ /**
+ * Converts the attributes from the resource offer into a Map of name -> attribute value,
+ * where the value is the Mesos protobuf value matching the attribute's declared type.
+ * @param offerAttributes the attributes advertised by the resource offer
+ * @return a map from attribute name to its Mesos value
+ */
+ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = {
+ offerAttributes.map(attr => {
+ val attrValue = attr.getType match {
+ case Value.Type.SCALAR => attr.getScalar
+ case Value.Type.RANGES => attr.getRanges
+ case Value.Type.SET => attr.getSet
+ case Value.Type.TEXT => attr.getText
+ }
+ (attr.getName, attrValue)
+ }).toMap
+ }
+
+
+ /**
+ * Match the requirements (if any) to the offer attributes:
+ * if no attribute requirements are specified, return true;
+ * if an attribute is required without values, a simple presence check is performed;
+ * if an attribute name and values are specified, a subset match is performed against
+ * the slave attributes.
+ */
+ def matchesAttributeRequirements(
+ slaveOfferConstraints: Map[String, Set[String]],
+ offerAttributes: Map[String, GeneratedMessage]): Boolean = {
+ slaveOfferConstraints.forall {
+ // offer has the required attribute and subsumes the required values for that attribute
+ case (name, requiredValues) =>
+ offerAttributes.get(name) match {
+ case None => false
+ case Some(_) if requiredValues.isEmpty => true // empty value matches presence
+ case Some(scalarValue: Value.Scalar) =>
+ // check if any of the provided values is less than or equal to the offered value
+ requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue)
+ case Some(rangeValue: Value.Range) =>
+ val offerRange = rangeValue.getBegin to rangeValue.getEnd
+ // Check if there is some required value that is between the ranges specified
+ // Note: We only support the ability to specify discrete values, in the future
+ // we may expand it to subsume ranges specified with a XX..YY value or something
+ // similar to that.
+ requiredValues.map(_.toLong).exists(offerRange.contains(_))
+ case Some(offeredValue: Value.Set) =>
+ // check if the specified required values are a subset of the offered set
+ requiredValues.subsetOf(offeredValue.getItemList.toSet)
+ case Some(textValue: Value.Text) =>
+ // check for an exact text match; if multiple values are specified,
+ // we succeed if any of them matches the offered value.
+ requiredValues.contains(textValue.getValue)
+ }
+ }
+ }
+
+ /**
+ * Parses the attribute constraints provided to Spark and builds a matching data structure:
+ * Map[attribute-name, Set[values-to-match]]
+ * The constraints are specified as ';' separated key-value pairs where keys and values
+ * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for
+ * multiple values (comma separated). For example:
+ * {{{
+ * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b")
+ * // would result in
+ *
+ * Map(
+ * "tachyon" -> Set("true"),
+ * "zone": -> Set("us-east-1a", "us-east-1b")
+ * )
+ * }}}
+ *
+ * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/
+ * https://github.com/apache/mesos/blob/master/src/common/values.cpp
+ * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp
+ *
+ * @param constraintsVal constraints string consisting of ';' separated key-value pairs (separated
+ * by ':')
+ * @return Map of constraints to match resources offers.
+ */
+ def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = {
+ /*
+ Based on mesos docs:
+ attributes : attribute ( ";" attribute )*
+ attribute : labelString ":" ( labelString | "," )+
+ labelString : [a-zA-Z0-9_/.-]
+ */
+ val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':')
+ // kv splitter
+ if (constraintsVal.isEmpty) {
+ Map()
+ } else {
+ try {
+ Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map {
+ case (k, v) =>
+ if (v == null || v.isEmpty) {
+ (k, Set[String]())
+ } else {
+ (k, v.split(',').toSet)
+ }
+ }
+ } catch {
+ case NonFatal(e) =>
+ throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e)
+ }
+ }
+ }
+
+ // These defaults copied from YARN
+ private val MEMORY_OVERHEAD_FRACTION = 0.10
+ private val MEMORY_OVERHEAD_MINIMUM = 384
+
+ /**
+ * Return the amount of memory to allocate to each executor, taking into account
+ * container overheads.
+ * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value
+ * @return the total memory requirement: the executor memory plus an overhead, where the
+ * overhead is `spark.mesos.executor.memoryOverhead` if set, otherwise
+ * max(0.1 * executorMemory, MEMORY_OVERHEAD_MINIMUM)
+ */
+ def calculateTotalMemory(sc: SparkContext): Int = {
+ sc.conf.getInt("spark.mesos.executor.memoryOverhead",
+ math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) +
+ sc.executorMemory
+ }
+
}
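Two quick worked examples (illustrative only) for the helpers added above: the shape `parseConstraintString` produces for a typical constraint string, and the numbers `calculateTotalMemory` yields with the copied-from-YARN defaults when `spark.mesos.executor.memoryOverhead` is not set.

```scala
// Constraint parsing (mirrors parseConstraintString above):
//   "tachyon:true;zone:us-east-1a,us-east-1b"
//     ==> Map("tachyon" -> Set("true"), "zone" -> Set("us-east-1a", "us-east-1b"))
//   "cache:"  ==> Map("cache" -> Set())            // presence-only constraint
//   "cache"   ==> IllegalArgumentException         // no ':' separator

// Memory sizing with MEMORY_OVERHEAD_FRACTION = 0.10 and MEMORY_OVERHEAD_MINIMUM = 384:
def totalMemoryMb(executorMemoryMb: Int): Int =
  math.max(0.10 * executorMemoryMb, 384).toInt + executorMemoryMb

assert(totalMemoryMb(1024) == 1024 + 384) // overhead floor applies: 1408 MB
assert(totalMemoryMb(8192) == 8192 + 819) // 10% overhead applies:  9011 MB
```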
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 3078a1b10be8b..776e5d330e3c7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -17,6 +17,8 @@
package org.apache.spark.scheduler.local
+import java.io.File
+import java.net.URL
import java.nio.ByteBuffer
import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState}
@@ -40,6 +42,7 @@ private case class StopExecutor()
*/
private[spark] class LocalEndpoint(
override val rpcEnv: RpcEnv,
+ userClassPath: Seq[URL],
scheduler: TaskSchedulerImpl,
executorBackend: LocalBackend,
private val totalCores: Int)
@@ -51,7 +54,7 @@ private[spark] class LocalEndpoint(
private val localExecutorHostname = "localhost"
private val executor = new Executor(
- localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true)
+ localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true)
override def receive: PartialFunction[Any, Unit] = {
case ReviveOffers =>
@@ -97,10 +100,22 @@ private[spark] class LocalBackend(
private val appId = "local-" + System.currentTimeMillis
var localEndpoint: RpcEndpointRef = null
+ private val userClassPath = getUserClasspath(conf)
+
+ /**
+ * Returns a list of URLs representing the user classpath.
+ *
+ * @param conf Spark configuration.
+ */
+ def getUserClasspath(conf: SparkConf): Seq[URL] = {
+ val userClassPathStr = conf.getOption("spark.executor.extraClassPath")
+ userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL)
+ }
override def start() {
localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint(
- "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores))
+ "LocalBackendEndpoint",
+ new LocalEndpoint(SparkEnv.get.rpcEnv, userClassPath, scheduler, this, totalCores))
}
override def stop() {
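For reference, a minimal sketch (not part of the patch) of what `getUserClasspath` above does with a typical `spark.executor.extraClassPath` value; the jar paths are made up.

```scala
import java.io.File
import java.net.URL
import org.apache.spark.SparkConf

// Hypothetical example: a platform-separated classpath string becomes the Seq[URL]
// handed to the local Executor.
val conf = new SparkConf(false).set("spark.executor.extraClassPath",
  Seq("/opt/jars/a.jar", "/opt/jars/b.jar").mkString(File.pathSeparator))

val urls: Seq[URL] = conf.getOption("spark.executor.extraClassPath")
  .map(_.split(File.pathSeparator)).toSeq.flatten
  .map(new File(_).toURI.toURL)
// => List(file:/opt/jars/a.jar, file:/opt/jars/b.jar); an unset key yields an empty Seq
```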
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 7cdae22b0e253..f70f701494dbf 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -33,7 +33,7 @@ class BlockManagerMaster(
isDriver: Boolean)
extends Logging {
- val timeout = RpcUtils.askTimeout(conf)
+ val timeout = RpcUtils.askRpcTimeout(conf)
/** Remove a dead executor from the driver endpoint. This is only called on the driver side. */
def removeExecutor(execId: String) {
@@ -106,7 +106,7 @@ class BlockManagerMaster(
logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e)
}(ThreadUtils.sameThread)
if (blocking) {
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
}
@@ -118,7 +118,7 @@ class BlockManagerMaster(
logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e)
}(ThreadUtils.sameThread)
if (blocking) {
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
}
@@ -132,7 +132,7 @@ class BlockManagerMaster(
s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e)
}(ThreadUtils.sameThread)
if (blocking) {
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
}
@@ -176,8 +176,8 @@ class BlockManagerMaster(
CanBuildFrom[Iterable[Future[Option[BlockStatus]]],
Option[BlockStatus],
Iterable[Option[BlockStatus]]]]
- val blockStatus = Await.result(
- Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread), timeout)
+ val blockStatus = timeout.awaitResult(
+ Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread))
if (blockStatus == null) {
throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
}
@@ -199,7 +199,7 @@ class BlockManagerMaster(
askSlaves: Boolean): Seq[BlockId] = {
val msg = GetMatchingBlockIds(filter, askSlaves)
val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg)
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 91ef86389a0c3..5f537692a16c5 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -124,10 +124,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
(blockId, getFile(blockId))
}
+ /**
+ * Create local directories for storing block data. These directories are
+ * located inside configured local directories and won't
+ * be deleted on JVM exit when using the external shuffle service.
+ */
private def createLocalDirs(conf: SparkConf): Array[File] = {
- Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir =>
+ Utils.getConfiguredLocalDirs(conf).flatMap { rootDir =>
try {
val localDir = Utils.createDirectory(rootDir, "blockmgr")
+ Utils.chmod700(localDir)
logInfo(s"Created local directory at $localDir")
Some(localDir)
} catch {
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 06e616220c706..c8356467fab87 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -68,7 +68,9 @@ private[spark] object JettyUtils extends Logging {
response.setStatus(HttpServletResponse.SC_OK)
val result = servletParams.responder(request)
response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ // scalastyle:off println
response.getWriter.println(servletParams.extractFn(result))
+ // scalastyle:on println
} else {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
@@ -210,10 +212,16 @@ private[spark] object JettyUtils extends Logging {
conf: SparkConf,
serverName: String = ""): ServerInfo = {
- val collection = new ContextHandlerCollection
- collection.setHandlers(handlers.toArray)
addFilters(handlers, conf)
+ val collection = new ContextHandlerCollection
+ val gzipHandlers = handlers.map { h =>
+ val gzipHandler = new GzipHandler
+ gzipHandler.setHandler(h)
+ gzipHandler
+ }
+ collection.setHandlers(gzipHandlers.toArray)
+
// Bind to the given port, or throw a java.net.BindException if the port is occupied
def connect(currentPort: Int): (Server, Int) = {
val server = new Server(new InetSocketAddress(hostName, currentPort))
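The gzip wrapping added above can be summarized by the following standalone sketch; the `GzipHandler` import path assumes the Jetty 8/9.0 line that Spark depended on at the time, and the handler types are simplified.

```scala
import org.eclipse.jetty.server.Handler
import org.eclipse.jetty.server.handler.{ContextHandlerCollection, GzipHandler}

// Sketch: wrap every handler in a GzipHandler so UI responses are compressed
// when the client sends Accept-Encoding: gzip, then register the wrappers.
def gzipAll(handlers: Seq[Handler]): ContextHandlerCollection = {
  val collection = new ContextHandlerCollection
  val gzipHandlers = handlers.map { h =>
    val gzipHandler = new GzipHandler
    gzipHandler.setHandler(h)
    gzipHandler
  }
  collection.setHandlers(gzipHandlers.toArray)
  collection
}
```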
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index ba03acdb38cc5..5a8c2914314c2 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -38,9 +38,11 @@ private[spark] object UIWorkloadGenerator {
def main(args: Array[String]) {
if (args.length < 3) {
+ // scalastyle:off println
println(
- "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " +
+ "Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " +
"[master] [FIFO|FAIR] [#job set (4 jobs per set)]")
+ // scalastyle:on println
System.exit(1)
}
@@ -96,6 +98,7 @@ private[spark] object UIWorkloadGenerator {
for ((desc, job) <- jobs) {
new Thread {
override def run() {
+ // scalastyle:off println
try {
setProperties(desc)
job()
@@ -106,6 +109,7 @@ private[spark] object UIWorkloadGenerator {
} finally {
barrier.release()
}
+ // scalastyle:on println
}
}.start
Thread.sleep(INTER_JOB_WAIT_MS)
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index 39583af14390d..a88fc4c37d3c9 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ui.exec
import scala.collection.mutable.HashMap
-import org.apache.spark.{ExceptionFailure, SparkContext}
+import org.apache.spark.{Resubmitted, ExceptionFailure, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageStatus, StorageStatusListener}
@@ -92,15 +92,22 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp
val info = taskEnd.taskInfo
if (info != null) {
val eid = info.executorId
- executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1
- executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration
taskEnd.reason match {
+ case Resubmitted =>
+ // Note: For resubmitted tasks, we continue to use the metrics that belong to the
+ // first attempt of this task. This may not be 100% accurate because the first attempt
+ // could have failed half-way through. The correct fix would be to keep track of the
+ // metrics added by each attempt, but this is much more complicated.
+ return
case e: ExceptionFailure =>
executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1
case _ =>
executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1
}
+ executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1
+ executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration
+
// Update shuffle read/write
val metrics = taskEnd.taskMetrics
if (metrics != null) {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index e96bf49d0dd14..ff0a339a39c65 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -332,7 +332,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+: getFormattedTimeQuantiles(serializationTimes)
val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) =>
- getGettingResultTime(info).toDouble
+ getGettingResultTime(info, currentTime).toDouble
}
val gettingResultQuantiles =
@@ -346,7 +346,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
- getSchedulerDelay(info, metrics.get).toDouble
+ getSchedulerDelay(info, metrics.get, currentTime).toDouble
}
val schedulerDelayTitle =
Scheduler Delay
@@ -544,7 +544,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val serializationTimeProportion = toProportion(serializationTime)
val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L)
val deserializationTimeProportion = toProportion(deserializationTime)
- val gettingResultTime = getGettingResultTime(taskUIData.taskInfo)
+ val gettingResultTime = getGettingResultTime(taskUIData.taskInfo, currentTime)
val gettingResultTimeProportion = toProportion(gettingResultTime)
val schedulerDelay = totalExecutionTime -
(executorComputingTime + shuffleReadTime + shuffleWriteTime +
@@ -570,6 +570,35 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val index = taskInfo.index
val attempt = taskInfo.attempt
+
+ val svgTag =
+ if (totalExecutionTime == 0) {
+ // SPARK-8705: Avoid invalid attribute error in JavaScript if execution time is 0
+ """"""
+ } else {
+ s"""""".stripMargin
+ }
val timelineObject =
s"""
|{
@@ -595,32 +624,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
| Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)}
| Result Serialization Time: ${UIUtils.formatDuration(serializationTime)}
| Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}">
- |',
+ |$svgTag',
|'start': new Date($launchTime),
|'end': new Date($finishTime)
|}
- |""".stripMargin.replaceAll("\n", " ")
+ |""".stripMargin.replaceAll("""[\r\n]+""", " ")
timelineObject
}.mkString("[", ",", "]")
@@ -677,11 +685,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
else metrics.map(_.executorRunTime).getOrElse(1L)
val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration)
else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("")
- val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L)
+ val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L)
val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
- val gettingResultTime = getGettingResultTime(info)
+ val gettingResultTime = getGettingResultTime(info, currentTime)
val maybeAccumulators = info.accumulables
val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"}
@@ -844,32 +852,31 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
{errorSummary}{details}
}
- private def getGettingResultTime(info: TaskInfo): Long = {
- if (info.gettingResultTime > 0) {
- if (info.finishTime > 0) {
+ private def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = {
+ if (info.gettingResult) {
+ if (info.finished) {
info.finishTime - info.gettingResultTime
} else {
// The task is still fetching the result.
- System.currentTimeMillis - info.gettingResultTime
+ currentTime - info.gettingResultTime
}
} else {
0L
}
}
- private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = {
- val totalExecutionTime =
- if (info.gettingResult) {
- info.gettingResultTime - info.launchTime
- } else if (info.finished) {
- info.finishTime - info.launchTime
- } else {
- 0
- }
- val executorOverhead = (metrics.executorDeserializeTime +
- metrics.resultSerializationTime)
- math.max(
- 0,
- totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info))
+ private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = {
+ if (info.finished) {
+ val totalExecutionTime = info.finishTime - info.launchTime
+ val executorOverhead = (metrics.executorDeserializeTime +
+ metrics.resultSerializationTime)
+ math.max(
+ 0,
+ totalExecutionTime - metrics.executorRunTime - executorOverhead -
+ getGettingResultTime(info, currentTime))
+ } else {
+ // The task is still running and the metrics like executorRunTime are not available.
+ 0L
+ }
}
}
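A worked example (hypothetical numbers) of the timing breakdown the rewritten `getGettingResultTime` / `getSchedulerDelay` helpers compute for a finished task:

```scala
// All timestamps in ms; the task has finished, so finishTime and gettingResultTime are set.
val launchTime        = 1000L
val gettingResultTime = 1800L // driver started fetching the result here
val finishTime        = 2000L
val executorRunTime   = 600L
val deserializeTime   = 50L
val serializationTime = 30L

val gettingResult      = finishTime - gettingResultTime   // 200 ms
val totalExecutionTime = finishTime - launchTime          // 1000 ms
val schedulerDelay = math.max(0L,
  totalExecutionTime - executorRunTime -
    (deserializeTime + serializationTime) - gettingResult) // 1000 - 600 - 80 - 200 = 120 ms
// A task that is still running now reports 0 rather than a partial, misleading value.
```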
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index 96aa2fe164703..c179833e5b06a 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -18,8 +18,6 @@
package org.apache.spark.util
import scala.collection.JavaConversions.mapAsJavaMap
-import scala.concurrent.Await
-import scala.concurrent.duration.FiniteDuration
import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem}
import akka.pattern.ask
@@ -28,6 +26,7 @@ import com.typesafe.config.ConfigFactory
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.rpc.RpcTimeout
/**
* Various utility classes for working with Akka.
@@ -147,7 +146,7 @@ private[spark] object AkkaUtils extends Logging {
def askWithReply[T](
message: Any,
actor: ActorRef,
- timeout: FiniteDuration): T = {
+ timeout: RpcTimeout): T = {
askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout)
}
@@ -160,7 +159,7 @@ private[spark] object AkkaUtils extends Logging {
actor: ActorRef,
maxAttempts: Int,
retryInterval: Long,
- timeout: FiniteDuration): T = {
+ timeout: RpcTimeout): T = {
// TODO: Consider removing multiple attempts
if (actor == null) {
throw new SparkException(s"Error sending message [message = $message]" +
@@ -171,8 +170,8 @@ private[spark] object AkkaUtils extends Logging {
while (attempts < maxAttempts) {
attempts += 1
try {
- val future = actor.ask(message)(timeout)
- val result = Await.result(future, timeout)
+ val future = actor.ask(message)(timeout.duration)
+ val result = timeout.awaitResult(future)
if (result == null) {
throw new SparkException("Actor returned null")
}
@@ -198,9 +197,9 @@ private[spark] object AkkaUtils extends Logging {
val driverPort: Int = conf.getInt("spark.driver.port", 7077)
Utils.checkHost(driverHost, "Expected hostname")
val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name)
- val timeout = RpcUtils.lookupTimeout(conf)
+ val timeout = RpcUtils.lookupRpcTimeout(conf)
logInfo(s"Connecting to $name: $url")
- Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
}
def makeExecutorRef(
@@ -212,9 +211,9 @@ private[spark] object AkkaUtils extends Logging {
val executorActorSystemName = SparkEnv.executorActorSystemName
Utils.checkHost(host, "Expected hostname")
val url = address(protocol(actorSystem), executorActorSystemName, host, port, name)
- val timeout = RpcUtils.lookupTimeout(conf)
+ val timeout = RpcUtils.lookupRpcTimeout(conf)
logInfo(s"Connecting to $name: $url")
- Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
}
def protocol(actorSystem: ActorSystem): String = {
diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala
index 1bab707235b89..950b69f7db641 100644
--- a/core/src/main/scala/org/apache/spark/util/Distribution.scala
+++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala
@@ -52,9 +52,11 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va
}
def showQuantiles(out: PrintStream = System.out): Unit = {
+ // scalastyle:off println
out.println("min\t25%\t50%\t75%\tmax")
getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")}
out.println
+ // scalastyle:on println
}
def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx))
@@ -64,8 +66,10 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va
* @param out
*/
def summary(out: PrintStream = System.out) {
+ // scalastyle:off println
out.println(statCounter)
showQuantiles(out)
+ // scalastyle:on println
}
}
@@ -80,8 +84,10 @@ private[spark] object Distribution {
}
def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) {
+ // scalastyle:off println
out.println("min\t25%\t50%\t75%\tmax")
quantiles.foreach{q => out.print(q + "\t")}
out.println
+ // scalastyle:on println
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
index f16cc8e7e42c6..7578a3b1d85f2 100644
--- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
@@ -17,11 +17,11 @@
package org.apache.spark.util
-import scala.concurrent.duration._
+import scala.concurrent.duration.FiniteDuration
import scala.language.postfixOps
import org.apache.spark.{SparkEnv, SparkConf}
-import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}
object RpcUtils {
@@ -47,14 +47,22 @@ object RpcUtils {
}
/** Returns the default Spark timeout to use for RPC ask operations. */
+ private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = {
+ RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s")
+ }
+
+ @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0")
def askTimeout(conf: SparkConf): FiniteDuration = {
- conf.getTimeAsSeconds("spark.rpc.askTimeout",
- conf.get("spark.network.timeout", "120s")) seconds
+ askRpcTimeout(conf).duration
}
/** Returns the default Spark timeout to use for RPC remote endpoint lookup. */
+ private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = {
+ RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s")
+ }
+
+ @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0")
def lookupTimeout(conf: SparkConf): FiniteDuration = {
- conf.getTimeAsSeconds("spark.rpc.lookupTimeout",
- conf.get("spark.network.timeout", "120s")) seconds
+ lookupRpcTimeout(conf).duration
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 19157af5b6f4d..b6b932104a94d 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -80,6 +80,12 @@ private[spark] object Utils extends Logging {
*/
val TEMP_DIR_SHUTDOWN_PRIORITY = 25
+ /**
+ * Define a default value for driver memory here since this value is referenced across the code
+ * base and nearly all files already use Utils.scala
+ */
+ val DEFAULT_DRIVER_MEM_MB = JavaUtils.DEFAULT_DRIVER_MEM_MB.toInt
+
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
@volatile private var localRootDirs: Array[String] = null
@@ -727,7 +733,12 @@ private[spark] object Utils extends Logging {
localRootDirs
}
- private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = {
+ /**
+ * Return the configured local directories where Spark can write files. This
+ * method does not create any directories on its own; it only encapsulates the
+ * logic of locating the local directories according to deployment mode.
+ */
+ def getConfiguredLocalDirs(conf: SparkConf): Array[String] = {
if (isRunningInYarnContainer(conf)) {
// If we are in yarn mode, systems can have different disk layouts so we must set it
// to what Yarn on this system said was available. Note this assumes that Yarn has
@@ -743,27 +754,29 @@ private[spark] object Utils extends Logging {
Option(conf.getenv("SPARK_LOCAL_DIRS"))
.getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")))
.split(",")
- .flatMap { root =>
- try {
- val rootDir = new File(root)
- if (rootDir.exists || rootDir.mkdirs()) {
- val dir = createTempDir(root)
- chmod700(dir)
- Some(dir.getAbsolutePath)
- } else {
- logError(s"Failed to create dir in $root. Ignoring this directory.")
- None
- }
- } catch {
- case e: IOException =>
- logError(s"Failed to create local root dir in $root. Ignoring this directory.")
- None
- }
- }
- .toArray
}
}
+ private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = {
+ getConfiguredLocalDirs(conf).flatMap { root =>
+ try {
+ val rootDir = new File(root)
+ if (rootDir.exists || rootDir.mkdirs()) {
+ val dir = createTempDir(root)
+ chmod700(dir)
+ Some(dir.getAbsolutePath)
+ } else {
+ logError(s"Failed to create dir in $root. Ignoring this directory.")
+ None
+ }
+ } catch {
+ case e: IOException =>
+ logError(s"Failed to create local root dir in $root. Ignoring this directory.")
+ None
+ }
+ }.toArray
+ }
+
/** Get the Yarn approved local directories. */
private def getYarnLocalDirs(conf: SparkConf): String = {
// Hadoop 0.23 and 2.x have different Environment variable names for the
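
The refactoring above only extracts the lookup into getConfiguredLocalDirs; the precedence itself is unchanged. For non-YARN deployments it can be summarized with this self-contained sketch (a hypothetical helper, not the Spark API):

```scala
// Precedence mirrored from the code above: SPARK_LOCAL_DIRS env var, then the
// spark.local.dir setting, then the JVM temp dir; values are comma-separated.
def resolveLocalDirs(env: Map[String, String], conf: Map[String, String]): Array[String] =
  env.get("SPARK_LOCAL_DIRS")
    .orElse(conf.get("spark.local.dir"))
    .getOrElse(System.getProperty("java.io.tmpdir"))
    .split(",")

// With neither source set this typically yields Array("/tmp") on Linux.
resolveLocalDirs(Map.empty, Map.empty)
```
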
@@ -2333,3 +2346,36 @@ private[spark] class RedirectThread(
}
}
}
+
+/**
+ * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it
+ * in a circular buffer. The current contents of the buffer can be accessed using
+ * the toString method.
+ */
+private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream {
+ var pos: Int = 0
+ var buffer = new Array[Int](sizeInBytes)
+
+ def write(i: Int): Unit = {
+ buffer(pos) = i
+ pos = (pos + 1) % buffer.length
+ }
+
+ override def toString: String = {
+ val (end, start) = buffer.splitAt(pos)
+ val input = new java.io.InputStream {
+ val iterator = (start ++ end).iterator
+
+ def read(): Int = if (iterator.hasNext) iterator.next() else -1
+ }
+ val reader = new BufferedReader(new InputStreamReader(input))
+ val stringBuilder = new StringBuilder
+ var line = reader.readLine()
+ while (line != null) {
+ stringBuilder.append(line)
+ stringBuilder.append("\n")
+ line = reader.readLine()
+ }
+ stringBuilder.toString()
+ }
+}
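
A short usage sketch for the CircularBuffer added above (it is private[spark], so this is illustrative only); the point is that only the most recently written bytes survive:

```scala
import java.io.PrintStream
import org.apache.spark.util.CircularBuffer

// Keep only the last 32 bytes of whatever is printed to this stream.
val tail = new CircularBuffer(32)
val out = new PrintStream(tail)
out.println("first line, long enough to be partially overwritten")
out.println("second line")
// toString reassembles the buffer starting from the oldest surviving byte.
println(tail.toString)
```
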
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index c4a7b4441c85c..85fb923cd9bc7 100644
--- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -70,12 +70,14 @@ private[spark] object XORShiftRandom {
* @param args takes one argument - the number of random numbers to generate
*/
def main(args: Array[String]): Unit = {
+ // scalastyle:off println
if (args.length != 1) {
println("Benchmark of XORShiftRandom vis-a-vis java.util.Random")
println("Usage: XORShiftRandom number_of_random_numbers_to_generate")
System.exit(1)
}
println(benchmark(args(0).toInt))
+ // scalastyle:on println
}
/**
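
Several hunks in this patch wrap intentional console output in scalastyle suppression comments; the rule being silenced is the println check. The pattern in isolation:

```scala
// Sketch of the suppression pattern used throughout the patch: the comments
// disable and re-enable the scalastyle "println" rule around deliberate output.
object PrintlnDemo {
  def main(args: Array[String]): Unit = {
    // scalastyle:off println
    println("intentional console output, exempt from the style check")
    // scalastyle:on println
  }
}
```
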
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
new file mode 100644
index 0000000000000..ea8755e21eb68
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.UUID;
+
+import scala.Tuple2;
+import scala.Tuple2$;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeExternalSorterSuite {
+
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+ return (int) prefix1 - (int) prefix2;
+ }
+ };
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+ @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+
+ File tempDir;
+
+ private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ @Override
+ public OutputStream apply(OutputStream stream) {
+ return stream;
+ }
+ }
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ tempDir = new File(Utils.createTempDir$default$1());
+ taskContext = mock(TaskContext.class);
+ when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+ when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+ when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+ when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+ @Override
+ public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+ TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+ File file = File.createTempFile("spillFile", ".spill", tempDir);
+ return Tuple2$.MODULE$.apply(blockId, file);
+ }
+ });
+ when(blockManager.getDiskWriter(
+ any(BlockId.class),
+ any(File.class),
+ any(SerializerInstance.class),
+ anyInt(),
+ any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+ @Override
+ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+ Object[] args = invocationOnMock.getArguments();
+
+ return new DiskBlockObjectWriter(
+ (BlockId) args[0],
+ (File) args[1],
+ (SerializerInstance) args[2],
+ (Integer) args[3],
+ new CompressStream(),
+ false,
+ (ShuffleWriteMetrics) args[4]
+ );
+ }
+ });
+ when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+ .then(returnsSecondArg());
+ }
+
+ private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
+ final int[] arr = new int[] { value };
+ sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
+ }
+
+ @Test
+ public void testSortingOnlyByPrefix() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ insertNumber(sorter, 5);
+ insertNumber(sorter, 1);
+ insertNumber(sorter, 3);
+ sorter.spill();
+ insertNumber(sorter, 4);
+ sorter.spill();
+ insertNumber(sorter, 2);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(i, iter.getKeyPrefix());
+ assertEquals(4, iter.getRecordLength());
+ // TODO: read rest of value.
+ }
+
+ // TODO: test for cleanup:
+ // assert(tempDir.isEmpty)
+ }
+
+ @Test
+ public void testSortingEmptyArrays() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(0, iter.getKeyPrefix());
+ assertEquals(0, iter.getRecordLength());
+ }
+ }
+
+}
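
Stepping back from the mocks, the suite above checks that records spilled in separate batches still come back as a single stream ordered by their 8-byte key prefix. A toy model of that contract in plain Scala (not the Spark API; the real sorter performs a streaming merge rather than a full re-sort):

```scala
// Records carry a numeric prefix; each spill produces an independently sorted
// run, and the final iterator must be globally ordered by prefix.
final case class Record(prefix: Long, payload: String)

def mergeRuns(runs: Seq[Seq[Record]]): Seq[Record] =
  runs.flatten.sortBy(_.prefix)

val run1 = Seq(Record(1, "a"), Record(3, "c"), Record(5, "e"))
val run2 = Seq(Record(2, "b"), Record(4, "d"))
assert(mergeRuns(Seq(run1, run2)).map(_.prefix) == Seq(1L, 2L, 3L, 4L, 5L))
```
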
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
new file mode 100644
index 0000000000000..909500930539c
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.util.Arrays;
+
+import org.junit.Test;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.*;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.mock;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeInMemorySorterSuite {
+
+ private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) {
+ final byte[] strBytes = new byte[length];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, length);
+ return new String(strBytes);
+ }
+
+ @Test
+ public void testSortingEmptyInput() {
+ final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
+ mock(RecordComparator.class),
+ mock(PrefixComparator.class),
+ 100);
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
+ assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testSortingOnlyByIntegerPrefix() throws Exception {
+ final String[] dataToSort = new String[] {
+ "Boba",
+ "Pearls",
+ "Tapioca",
+ "Taho",
+ "Condensed Milk",
+ "Jasmine",
+ "Milk Tea",
+ "Lychee",
+ "Mango"
+ };
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ final Object baseObject = dataPage.getBaseObject();
+ // Write the records into the data page:
+ long position = dataPage.getBaseOffset();
+ for (String str : dataToSort) {
+ final byte[] strBytes = str.getBytes("utf-8");
+ PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+ position += 4;
+ PlatformDependent.copyMemory(
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ position,
+ strBytes.length);
+ position += strBytes.length;
+ }
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+ // Compute key prefixes based on the records' partition ids
+ final HashPartitioner hashPartitioner = new HashPartitioner(4);
+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+ return (int) prefix1 - (int) prefix2;
+ }
+ };
+ UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
+ prefixComparator, dataToSort.length);
+ // Given a page of records, insert those records into the sorter one-by-one:
+ position = dataPage.getBaseOffset();
+ for (int i = 0; i < dataToSort.length; i++) {
+ // position now points to the start of a record (which holds its length).
+ final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position);
+ final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
+ final String str = getStringFromDataPage(baseObject, position + 4, recordLength);
+ final int partitionId = hashPartitioner.getPartition(str);
+ sorter.insertRecord(address, partitionId);
+ position += 4 + recordLength;
+ }
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
+ int iterLength = 0;
+ long prevPrefix = -1;
+ Arrays.sort(dataToSort);
+ while (iter.hasNext()) {
+ iter.loadNext();
+ final String str =
+ getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength());
+ final long keyPrefix = iter.getKeyPrefix();
+ assertThat(str, isIn(Arrays.asList(dataToSort)));
+ assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix));
+ prevPrefix = keyPrefix;
+ iterLength++;
+ }
+ assertEquals(dataToSort.length, iterLength);
+ }
+}
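
For reference, the test above hand-writes its records into a data page as a 4-byte length followed by the string's UTF-8 bytes. The same layout can be sketched with a plain ByteBuffer (illustrative only, no Spark unsafe APIs):

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

// Encode each record as [int length][UTF-8 bytes], back to back.
def writeRecords(strings: Seq[String]): ByteBuffer = {
  val encoded = strings.map(_.getBytes(StandardCharsets.UTF_8))
  val buf = ByteBuffer.allocate(encoded.map(_.length + 4).sum)
  encoded.foreach { b => buf.putInt(b.length); buf.put(b) }
  buf.flip()
  buf
}

// Decode by reading the length, then that many bytes, until the page is exhausted.
def readRecords(buf: ByteBuffer): Seq[String] = {
  val out = Seq.newBuilder[String]
  while (buf.hasRemaining) {
    val len = buf.getInt()
    val bytes = new Array[Byte](len)
    buf.get(bytes)
    out += new String(bytes, StandardCharsets.UTF_8)
  }
  out.result()
}

assert(readRecords(writeRecords(Seq("Boba", "Pearls", "Taho"))) == Seq("Boba", "Pearls", "Taho"))
```
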
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d1761a48babbc..cc50e6d79a3e2 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -46,7 +46,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
flatMappedRDD.checkpoint()
- assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ assert(flatMappedRDD.dependencies.head.rdd === parCollection)
val result = flatMappedRDD.collect()
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
assert(flatMappedRDD.collect() === result)
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 9c191ed52206d..2300bcff4f118 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -107,7 +107,9 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
sc = new SparkContext(clusterUrl, "test")
val accum = sc.accumulator(0)
val thrown = intercept[SparkException] {
+ // scalastyle:off println
sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
+ // scalastyle:on println
}
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("failed 4 times"))
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index a8c8c6f73fb5a..b099cd3fb7965 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -130,7 +130,9 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
+ // scalastyle:off println
sc.parallelize(1 to 10, 2).foreach(x => println(a))
+ // scalastyle:on println
}
assert(thrown2.getClass === classOf[SparkException])
assert(thrown2.getMessage.contains("NotSerializableException") ||
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 6e65b0a8f6c76..876418aa13029 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -51,7 +51,9 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext {
val textFile = new File(testTempDir, "FileServerSuite.txt")
val pw = new PrintWriter(textFile)
+ // scalastyle:off println
pw.println("100")
+ // scalastyle:on println
pw.close()
val jarFile = new File(testTempDir, "test.jar")
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index 911b3bddd1836..b31b09196608f 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -17,64 +17,145 @@
package org.apache.spark
-import scala.concurrent.duration._
import scala.language.postfixOps
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.storage.BlockManagerId
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.mockito.Mockito.{mock, spy, verify, when}
import org.mockito.Matchers
import org.mockito.Matchers._
-import org.apache.spark.scheduler.TaskScheduler
-import org.apache.spark.util.RpcUtils
-import org.scalatest.concurrent.Eventually._
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.scheduler._
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.ManualClock
-class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext {
+class HeartbeatReceiverSuite
+ extends SparkFunSuite
+ with BeforeAndAfterEach
+ with PrivateMethodTester
+ with LocalSparkContext {
- test("HeartbeatReceiver") {
+ private val executorId1 = "executor-1"
+ private val executorId2 = "executor-2"
+
+ // Shared state that must be reset before and after each test
+ private var scheduler: TaskScheduler = null
+ private var heartbeatReceiver: HeartbeatReceiver = null
+ private var heartbeatReceiverRef: RpcEndpointRef = null
+ private var heartbeatReceiverClock: ManualClock = null
+
+ override def beforeEach(): Unit = {
sc = spy(new SparkContext("local[2]", "test"))
- val scheduler = mock(classOf[TaskScheduler])
- when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
+ scheduler = mock(classOf[TaskScheduler])
when(sc.taskScheduler).thenReturn(scheduler)
+ heartbeatReceiverClock = new ManualClock
+ heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock)
+ heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver)
+ when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
+ }
- val heartbeatReceiver = new HeartbeatReceiver(sc)
- sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
- eventually(timeout(5 seconds), interval(5 millis)) {
- assert(heartbeatReceiver.scheduler != null)
- }
- val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+ override def afterEach(): Unit = {
+ resetSparkContext()
+ scheduler = null
+ heartbeatReceiver = null
+ heartbeatReceiverRef = null
+ heartbeatReceiverClock = null
+ }
- val metrics = new TaskMetrics
- val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
- val response = receiverRef.askWithRetry[HeartbeatResponse](
- Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+ test("task scheduler is set correctly") {
+ assert(heartbeatReceiver.scheduler === null)
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ assert(heartbeatReceiver.scheduler !== null)
+ }
- verify(scheduler).executorHeartbeatReceived(
- Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
- assert(false === response.reregisterBlockManager)
+ test("normal heartbeat") {
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+ triggerHeartbeat(executorId1, executorShouldReregister = false)
+ triggerHeartbeat(executorId2, executorShouldReregister = false)
+ val trackedExecutors = executorLastSeen(heartbeatReceiver)
+ assert(trackedExecutors.size === 2)
+ assert(trackedExecutors.contains(executorId1))
+ assert(trackedExecutors.contains(executorId2))
}
- test("HeartbeatReceiver re-register") {
- sc = spy(new SparkContext("local[2]", "test"))
- val scheduler = mock(classOf[TaskScheduler])
- when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false)
- when(sc.taskScheduler).thenReturn(scheduler)
+ test("reregister if scheduler is not ready yet") {
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+ // Task scheduler not set in HeartbeatReceiver
+ triggerHeartbeat(executorId1, executorShouldReregister = true)
+ }
- val heartbeatReceiver = new HeartbeatReceiver(sc)
- sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
- eventually(timeout(5 seconds), interval(5 millis)) {
- assert(heartbeatReceiver.scheduler != null)
- }
- val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+ test("reregister if heartbeat from unregistered executor") {
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ // Received heartbeat from an unknown executor, so we ask it to re-register
+ triggerHeartbeat(executorId1, executorShouldReregister = true)
+ assert(executorLastSeen(heartbeatReceiver).isEmpty)
+ }
+
+ test("reregister if heartbeat from removed executor") {
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+ // Remove the second executor but not the first
+ heartbeatReceiver.onExecutorRemoved(SparkListenerExecutorRemoved(0, executorId2, "bad boy"))
+ // Now trigger the heartbeats
+ // A heartbeat from the second executor should require reregistering
+ triggerHeartbeat(executorId1, executorShouldReregister = false)
+ triggerHeartbeat(executorId2, executorShouldReregister = true)
+ val trackedExecutors = executorLastSeen(heartbeatReceiver)
+ assert(trackedExecutors.size === 1)
+ assert(trackedExecutors.contains(executorId1))
+ assert(!trackedExecutors.contains(executorId2))
+ }
+ test("expire dead hosts") {
+ val executorTimeout = executorTimeoutMs(heartbeatReceiver)
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+ triggerHeartbeat(executorId1, executorShouldReregister = false)
+ triggerHeartbeat(executorId2, executorShouldReregister = false)
+ // Advance the clock and only trigger a heartbeat for the first executor
+ heartbeatReceiverClock.advance(executorTimeout / 2)
+ triggerHeartbeat(executorId1, executorShouldReregister = false)
+ heartbeatReceiverClock.advance(executorTimeout)
+ heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts)
+ // Only the second executor should be expired as a dead host
+ verify(scheduler).executorLost(Matchers.eq(executorId2), any())
+ val trackedExecutors = executorLastSeen(heartbeatReceiver)
+ assert(trackedExecutors.size === 1)
+ assert(trackedExecutors.contains(executorId1))
+ assert(!trackedExecutors.contains(executorId2))
+ }
+
+ /** Manually send a heartbeat and verify the response. */
+ private def triggerHeartbeat(
+ executorId: String,
+ executorShouldReregister: Boolean): Unit = {
val metrics = new TaskMetrics
- val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
- val response = receiverRef.askWithRetry[HeartbeatResponse](
- Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+ val blockManagerId = BlockManagerId(executorId, "localhost", 12345)
+ val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
+ Heartbeat(executorId, Array(1L -> metrics), blockManagerId))
+ if (executorShouldReregister) {
+ assert(response.reregisterBlockManager)
+ } else {
+ assert(!response.reregisterBlockManager)
+ // Additionally verify that the scheduler callback is called with the correct parameters
+ verify(scheduler).executorHeartbeatReceived(
+ Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+ }
+ }
- verify(scheduler).executorHeartbeatReceived(
- Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
- assert(true === response.reregisterBlockManager)
+ // Helper methods to access private fields in HeartbeatReceiver
+ private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen)
+ private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs)
+ private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = {
+ receiver invokePrivate _executorLastSeen()
+ }
+ private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = {
+ receiver invokePrivate _executorTimeoutMs()
}
+
}
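
The rewritten suite drives executor expiry with a ManualClock instead of real sleeps. A minimal sketch of that pattern; advance appears in the diff, while getTimeMillis is assumed from Spark's Clock interface:

```scala
import org.apache.spark.util.ManualClock

// Deterministic timeout check: no Thread.sleep, just advance the clock.
val clock = new ManualClock
val timeoutMs = 120000L
val lastHeartbeat = clock.getTimeMillis()
clock.advance(timeoutMs / 2)
assert(clock.getTimeMillis() - lastHeartbeat <= timeoutMs) // not yet expired
clock.advance(timeoutMs)
assert(clock.getTimeMillis() - lastHeartbeat > timeoutMs)  // now expired
```
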
diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
index 376481ba541fa..25b79bce6ab98 100644
--- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.File
+import javax.net.ssl.SSLContext
import com.google.common.io.Files
import org.apache.spark.util.Utils
@@ -29,6 +30,15 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath
val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath
+ // Pick two cipher suites that the provider knows about
+ val sslContext = SSLContext.getInstance("TLSv1.2")
+ sslContext.init(null, null, null)
+ val algorithms = sslContext
+ .getServerSocketFactory
+ .getDefaultCipherSuites
+ .take(2)
+ .toSet
+
val conf = new SparkConf
conf.set("spark.ssl.enabled", "true")
conf.set("spark.ssl.keyStore", keyStorePath)
@@ -36,9 +46,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
conf.set("spark.ssl.keyPassword", "password")
conf.set("spark.ssl.trustStore", trustStorePath)
conf.set("spark.ssl.trustStorePassword", "password")
- conf.set("spark.ssl.enabledAlgorithms",
- "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA")
- conf.set("spark.ssl.protocol", "SSLv3")
+ conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(","))
+ conf.set("spark.ssl.protocol", "TLSv1.2")
val opts = SSLOptions.parse(conf, "spark.ssl")
@@ -52,9 +61,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(opts.trustStorePassword === Some("password"))
assert(opts.keyStorePassword === Some("password"))
assert(opts.keyPassword === Some("password"))
- assert(opts.protocol === Some("SSLv3"))
- assert(opts.enabledAlgorithms ===
- Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"))
+ assert(opts.protocol === Some("TLSv1.2"))
+ assert(opts.enabledAlgorithms === algorithms)
}
test("test resolving property with defaults specified ") {
diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
index 1a099da2c6c8e..33270bec6247c 100644
--- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
+++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
@@ -25,6 +25,20 @@ object SSLSampleConfigs {
this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath
val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath
+ val enabledAlgorithms =
+ // A reasonable set of TLSv1.2 Oracle security provider suites
+ "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " +
+ "TLS_RSA_WITH_AES_256_CBC_SHA256, " +
+ "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " +
+ "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " +
+ "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " +
+ // and their equivalent names in the IBM Security provider
+ "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " +
+ "SSL_RSA_WITH_AES_256_CBC_SHA256, " +
+ "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " +
+ "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " +
+ "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256"
+
def sparkSSLConfig(): SparkConf = {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.ssl.enabled", "true")
@@ -33,9 +47,8 @@ object SSLSampleConfigs {
conf.set("spark.ssl.keyPassword", "password")
conf.set("spark.ssl.trustStore", trustStorePath)
conf.set("spark.ssl.trustStorePassword", "password")
- conf.set("spark.ssl.enabledAlgorithms",
- "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA")
- conf.set("spark.ssl.protocol", "TLSv1")
+ conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms)
+ conf.set("spark.ssl.protocol", "TLSv1.2")
conf
}
@@ -47,9 +60,8 @@ object SSLSampleConfigs {
conf.set("spark.ssl.keyPassword", "password")
conf.set("spark.ssl.trustStore", trustStorePath)
conf.set("spark.ssl.trustStorePassword", "password")
- conf.set("spark.ssl.enabledAlgorithms",
- "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA")
- conf.set("spark.ssl.protocol", "TLSv1")
+ conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms)
+ conf.set("spark.ssl.protocol", "TLSv1.2")
conf
}
diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
index e9b64aa82a17a..f34aefca4eb18 100644
--- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
@@ -127,6 +127,17 @@ class SecurityManagerSuite extends SparkFunSuite {
test("ssl on setup") {
val conf = SSLSampleConfigs.sparkSSLConfig()
+ val expectedAlgorithms = Set(
+ "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
+ "TLS_RSA_WITH_AES_256_CBC_SHA256",
+ "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256",
+ "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
+ "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256",
+ "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
+ "SSL_RSA_WITH_AES_256_CBC_SHA256",
+ "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256",
+ "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
+ "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256")
val securityManager = new SecurityManager(conf)
@@ -143,9 +154,8 @@ class SecurityManagerSuite extends SparkFunSuite {
assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password"))
assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password"))
assert(securityManager.fileServerSSLOptions.keyPassword === Some("password"))
- assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1"))
- assert(securityManager.fileServerSSLOptions.enabledAlgorithms ===
- Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA"))
+ assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2"))
+ assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms)
assert(securityManager.akkaSSLOptions.trustStore.isDefined === true)
assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore")
@@ -154,9 +164,8 @@ class SecurityManagerSuite extends SparkFunSuite {
assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password"))
assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password"))
assert(securityManager.akkaSSLOptions.keyPassword === Some("password"))
- assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1"))
- assert(securityManager.akkaSSLOptions.enabledAlgorithms ===
- Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA"))
+ assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1.2"))
+ assert(securityManager.akkaSSLOptions.enabledAlgorithms === expectedAlgorithms)
}
test("ssl off setup") {
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 9fbaeb33f97cd..90cb7da94e88a 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -260,10 +260,10 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
assert(RpcUtils.retryWaitMs(conf) === 2L)
conf.set("spark.akka.askTimeout", "3")
- assert(RpcUtils.askTimeout(conf) === (3 seconds))
+ assert(RpcUtils.askRpcTimeout(conf).duration === (3 seconds))
conf.set("spark.akka.lookupTimeout", "4")
- assert(RpcUtils.lookupTimeout(conf) === (4 seconds))
+ assert(RpcUtils.lookupRpcTimeout(conf).duration === (4 seconds))
}
}
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 6838b35ab4cc8..5c57940fa5f77 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.util.Utils
import scala.concurrent.Await
import scala.concurrent.duration.Duration
+import org.scalatest.Matchers._
class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
@@ -272,4 +273,16 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
sc.stop()
}
}
+
+ test("calling multiple sc.stop() must not throw any exception") {
+ noException should be thrownBy {
+ sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+ val cnt = sc.parallelize(1 to 4).count()
+ sc.cancelAllJobs()
+ sc.stop()
+ // call stop a second time
+ sc.stop()
+ }
+ }
+
}
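
The new test asserts that a second sc.stop() is a no-op rather than an error. A generic sketch of the guard that makes a stop() idempotent (a hypothetical class, not SparkContext's actual implementation):

```scala
import java.util.concurrent.atomic.AtomicBoolean

class StoppableService {
  private val stopped = new AtomicBoolean(false)

  def stop(): Unit = {
    // Only the first caller wins; subsequent calls return quietly.
    if (stopped.compareAndSet(false, true)) {
      // release resources exactly once
    }
  }
}

val service = new StoppableService
service.stop()
service.stop() // safe: repeated calls do not throw
```
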
diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
index 6580139df6c60..48509f0759a3b 100644
--- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
@@ -36,7 +36,7 @@ object ThreadingSuiteState {
}
}
-class ThreadingSuite extends SparkFunSuite with LocalSparkContext {
+class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
test("accessing SparkContext form a different thread") {
sc = new SparkContext("local", "test")
@@ -130,8 +130,6 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext {
Thread.sleep(100)
}
if (running.get() != 4) {
- println("Waited 1 second without seeing runningThreads = 4 (it was " +
- running.get() + "); failing test")
ThreadingSuiteState.failed.set(true)
}
number
@@ -143,6 +141,8 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext {
}
sem.acquire(2)
if (ThreadingSuiteState.failed.get()) {
+ logError("Waited 1 second without seeing runningThreads = 4 (it was " +
+ ThreadingSuiteState.runningThreads.get() + "); failing test")
fail("One or more threads didn't see runningThreads = 4")
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 357ed90be3f5c..e7878bde6fcb0 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -51,9 +51,11 @@ class SparkSubmitSuite
/** Simple PrintStream that collects printed lines into a buffer */
private class BufferPrintStream extends PrintStream(noOpOutputStream) {
var lineBuffer = ArrayBuffer[String]()
+ // scalastyle:off println
override def println(line: String) {
lineBuffer += line
}
+ // scalastyle:on println
}
/** Returns true if the script exits and the given search string is printed. */
@@ -81,6 +83,7 @@ class SparkSubmitSuite
}
}
+ // scalastyle:off println
test("prints usage on empty input") {
testPrematureExit(Array[String](), "Usage: spark-submit")
}
@@ -243,7 +246,7 @@ class SparkSubmitSuite
mainClass should be ("org.apache.spark.deploy.Client")
}
classpath should have size 0
- sysProps should have size 8
+ sysProps should have size 9
sysProps.keys should contain ("SPARK_SUBMIT")
sysProps.keys should contain ("spark.master")
sysProps.keys should contain ("spark.app.name")
@@ -252,6 +255,7 @@ class SparkSubmitSuite
sysProps.keys should contain ("spark.driver.cores")
sysProps.keys should contain ("spark.driver.supervise")
sysProps.keys should contain ("spark.shuffle.spill")
+ sysProps.keys should contain ("spark.submit.deployMode")
sysProps("spark.shuffle.spill") should be ("false")
}
@@ -491,6 +495,7 @@ class SparkSubmitSuite
appArgs.executorMemory should be ("2.3g")
}
}
+ // scalastyle:on println
// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -548,6 +553,7 @@ object JarCreationTest extends Logging {
if (result.nonEmpty) {
throw new Exception("Could not load user class from jar:\n" + result(0))
}
+ sc.stop()
}
}
@@ -573,6 +579,7 @@ object SimpleApplicationTest {
s"Master had $config=$masterValue but executor had $config=$executorValue")
}
}
+ sc.stop()
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
index 12c40f0b7d658..01ece1a10f46d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
@@ -41,9 +41,11 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
/** Simple PrintStream that collects printed lines into a buffer */
private class BufferPrintStream extends PrintStream(noOpOutputStream) {
var lineBuffer = ArrayBuffer[String]()
+ // scalastyle:off println
override def println(line: String) {
lineBuffer += line
}
+ // scalastyle:on println
}
override def beforeAll() {
@@ -77,9 +79,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(resolver2.getResolvers.size() === 7)
val expected = repos.split(",").map(r => s"$r/")
resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) =>
- if (i > 3) {
- assert(resolver.getName === s"repo-${i - 3}")
- assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i - 4))
+ if (i < 3) {
+ assert(resolver.getName === s"repo-${i + 1}")
+ assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i))
}
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index 09075eeb539aa..2a62450bcdbad 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -39,6 +39,8 @@ import org.apache.spark.util.{JsonProtocol, ManualClock, Utils}
class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
+ import FsHistoryProvider._
+
private var testDir: File = null
before {
@@ -67,7 +69,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
// Write a new-style application log.
val newAppComplete = newLogFile("new1", None, inProgress = false)
writeFile(newAppComplete, true, None,
- SparkListenerApplicationStart("new-app-complete", None, 1L, "test", None),
+ SparkListenerApplicationStart(newAppComplete.getName(), Some("new-app-complete"), 1L, "test",
+ None),
SparkListenerApplicationEnd(5L)
)
@@ -75,35 +78,30 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
val newAppCompressedComplete = newLogFile("new1compressed", None, inProgress = false,
Some("lzf"))
writeFile(newAppCompressedComplete, true, None,
- SparkListenerApplicationStart("new-app-compressed-complete", None, 1L, "test", None),
+ SparkListenerApplicationStart(newAppCompressedComplete.getName(), Some("new-complete-lzf"),
+ 1L, "test", None),
SparkListenerApplicationEnd(4L))
// Write an unfinished app, new-style.
val newAppIncomplete = newLogFile("new2", None, inProgress = true)
writeFile(newAppIncomplete, true, None,
- SparkListenerApplicationStart("new-app-incomplete", None, 1L, "test", None)
+ SparkListenerApplicationStart(newAppIncomplete.getName(), Some("new-incomplete"), 1L, "test",
+ None)
)
// Write an old-style application log.
- val oldAppComplete = new File(testDir, "old1")
- oldAppComplete.mkdir()
- createEmptyFile(new File(oldAppComplete, provider.SPARK_VERSION_PREFIX + "1.0"))
- writeFile(new File(oldAppComplete, provider.LOG_PREFIX + "1"), false, None,
- SparkListenerApplicationStart("old-app-complete", None, 2L, "test", None),
+ val oldAppComplete = writeOldLog("old1", "1.0", None, true,
+ SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None),
SparkListenerApplicationEnd(3L)
)
- createEmptyFile(new File(oldAppComplete, provider.APPLICATION_COMPLETE))
// Check for logs so that we force the older unfinished app to be loaded, to make
// sure unfinished apps are also sorted correctly.
provider.checkForLogs()
// Write an unfinished app, old-style.
- val oldAppIncomplete = new File(testDir, "old2")
- oldAppIncomplete.mkdir()
- createEmptyFile(new File(oldAppIncomplete, provider.SPARK_VERSION_PREFIX + "1.0"))
- writeFile(new File(oldAppIncomplete, provider.LOG_PREFIX + "1"), false, None,
- SparkListenerApplicationStart("old-app-incomplete", None, 2L, "test", None)
+ val oldAppIncomplete = writeOldLog("old2", "1.0", None, false,
+ SparkListenerApplicationStart("old2", None, 2L, "test", None)
)
// Force a reload of data from the log directory, and check that both logs are loaded.
@@ -124,16 +122,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed)))
}
- list(0) should be (makeAppInfo(newAppComplete.getName(), "new-app-complete", 1L, 5L,
+ list(0) should be (makeAppInfo("new-app-complete", newAppComplete.getName(), 1L, 5L,
newAppComplete.lastModified(), "test", true))
- list(1) should be (makeAppInfo(newAppCompressedComplete.getName(),
- "new-app-compressed-complete", 1L, 4L, newAppCompressedComplete.lastModified(), "test",
- true))
- list(2) should be (makeAppInfo(oldAppComplete.getName(), "old-app-complete", 2L, 3L,
+ list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(),
+ 1L, 4L, newAppCompressedComplete.lastModified(), "test", true))
+ list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L,
oldAppComplete.lastModified(), "test", true))
- list(3) should be (makeAppInfo(oldAppIncomplete.getName(), "old-app-incomplete", 2L, -1L,
- oldAppIncomplete.lastModified(), "test", false))
- list(4) should be (makeAppInfo(newAppIncomplete.getName(), "new-app-incomplete", 1L, -1L,
+ list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L,
+ -1L, oldAppIncomplete.lastModified(), "test", false))
+ list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L,
newAppIncomplete.lastModified(), "test", false))
// Make sure the UI can be rendered.
@@ -155,12 +152,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null
val logDir = new File(testDir, codecName)
logDir.mkdir()
- createEmptyFile(new File(logDir, provider.SPARK_VERSION_PREFIX + "1.0"))
- writeFile(new File(logDir, provider.LOG_PREFIX + "1"), false, Option(codec),
+ createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0"))
+ writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec),
SparkListenerApplicationStart("app2", None, 2L, "test", None),
SparkListenerApplicationEnd(3L)
)
- createEmptyFile(new File(logDir, provider.COMPRESSION_CODEC_PREFIX + codecName))
+ createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName))
val logPath = new Path(logDir.getAbsolutePath())
try {
@@ -180,12 +177,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
test("SPARK-3697: ignore directories that cannot be read.") {
val logFile1 = newLogFile("new1", None, inProgress = false)
writeFile(logFile1, true, None,
- SparkListenerApplicationStart("app1-1", None, 1L, "test", None),
+ SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None),
SparkListenerApplicationEnd(2L)
)
val logFile2 = newLogFile("new2", None, inProgress = false)
writeFile(logFile2, true, None,
- SparkListenerApplicationStart("app1-2", None, 1L, "test", None),
+ SparkListenerApplicationStart("app1-2", Some("app1-2"), 1L, "test", None),
SparkListenerApplicationEnd(2L)
)
logFile2.setReadable(false, false)
@@ -218,6 +215,18 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
}
}
+ test("Parse logs that application is not started") {
+ val provider = new FsHistoryProvider((createTestConf()))
+
+ val logFile1 = newLogFile("app1", None, inProgress = true)
+ writeFile(logFile1, true, None,
+ SparkListenerLogStart("1.4")
+ )
+ updateAndCheck(provider) { list =>
+ list.size should be (0)
+ }
+ }
+
test("SPARK-5582: empty log directory") {
val provider = new FsHistoryProvider(createTestConf())
@@ -373,6 +382,33 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
}
}
+ test("SPARK-8372: new logs with no app ID are ignored") {
+ val provider = new FsHistoryProvider(createTestConf())
+
+ // Write a new log file without an app id, to make sure it's ignored.
+ val logFile1 = newLogFile("app1", None, inProgress = true)
+ writeFile(logFile1, true, None,
+ SparkListenerLogStart("1.4")
+ )
+
+ // Write a 1.2 log file with no start event (= no app id); it should be ignored.
+ writeOldLog("v12Log", "1.2", None, false)
+
+ // Write 1.0 and 1.1 logs, which don't have app ids.
+ writeOldLog("v11Log", "1.1", None, true,
+ SparkListenerApplicationStart("v11Log", None, 2L, "test", None),
+ SparkListenerApplicationEnd(3L))
+ writeOldLog("v10Log", "1.0", None, true,
+ SparkListenerApplicationStart("v10Log", None, 2L, "test", None),
+ SparkListenerApplicationEnd(4L))
+
+ updateAndCheck(provider) { list =>
+ list.size should be (2)
+ list(0).id should be ("v10Log")
+ list(1).id should be ("v11Log")
+ }
+ }
+
/**
* Asks the provider to check for logs and calls a function to perform checks on the updated
* app list. Example:
@@ -412,4 +448,23 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath())
}
+ private def writeOldLog(
+ fname: String,
+ sparkVersion: String,
+ codec: Option[CompressionCodec],
+ completed: Boolean,
+ events: SparkListenerEvent*): File = {
+ val log = new File(testDir, fname)
+ log.mkdir()
+
+ val oldEventLog = new File(log, LOG_PREFIX + "1")
+ createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion))
+ writeFile(oldEventLog, false, codec, events: _*)
+ if (completed) {
+ createEmptyFile(new File(log, APPLICATION_COMPLETE))
+ }
+
+ log
+ }
+
}
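
The writeOldLog helper above recreates the pre-1.3 on-disk layout: a per-application directory holding a version marker, the event log itself, and a completion marker for finished apps. A standalone sketch of that layout; the literal file names here are assumptions based on the prefix constants the test imports:

```scala
import java.io.File

// Hypothetical recreation of the old-style layout; marker names are assumed to
// correspond to FsHistoryProvider's LOG_PREFIX / SPARK_VERSION_PREFIX constants.
def writeOldStyleLogDir(parent: File, name: String, version: String, completed: Boolean): File = {
  val dir = new File(parent, name)
  dir.mkdir()
  new File(dir, s"SPARK_VERSION_$version").createNewFile() // version marker
  new File(dir, "EVENT_LOG_1").createNewFile()             // the event log itself
  if (completed) new File(dir, "APPLICATION_COMPLETE").createNewFile()
  dir
}
```
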
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 014e87bb40254..9cb6dd43bac47 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -19,63 +19,21 @@ package org.apache.spark.deploy.master
import java.util.Date
-import scala.concurrent.Await
import scala.concurrent.duration._
import scala.io.Source
import scala.language.postfixOps
-import akka.actor.Address
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.scalatest.Matchers
import org.scalatest.concurrent.Eventually
import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy._
class MasterSuite extends SparkFunSuite with Matchers with Eventually {
- test("toAkkaUrl") {
- val conf = new SparkConf(loadDefaults = false)
- val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.tcp")
- assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl)
- }
-
- test("toAkkaUrl with SSL") {
- val conf = new SparkConf(loadDefaults = false)
- val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.ssl.tcp")
- assert("akka.ssl.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl)
- }
-
- test("toAkkaUrl: a typo url") {
- val conf = new SparkConf(loadDefaults = false)
- val e = intercept[SparkException] {
- Master.toAkkaUrl("spark://1.2. 3.4:1234", "akka.tcp")
- }
- assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
- }
-
- test("toAkkaAddress") {
- val conf = new SparkConf(loadDefaults = false)
- val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.tcp")
- assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address)
- }
-
- test("toAkkaAddress with SSL") {
- val conf = new SparkConf(loadDefaults = false)
- val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.ssl.tcp")
- assert(Address("akka.ssl.tcp", "sparkMaster", "1.2.3.4", 1234) === address)
- }
-
- test("toAkkaAddress: a typo url") {
- val conf = new SparkConf(loadDefaults = false)
- val e = intercept[SparkException] {
- Master.toAkkaAddress("spark://1.2. 3.4:1234", "akka.tcp")
- }
- assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
- }
-
test("can use a custom recovery mode factory") {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.deploy.recoveryMode", "CUSTOM")
@@ -129,16 +87,16 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually {
port = 10000,
cores = 0,
memory = 0,
- actor = null,
+ endpoint = null,
webUiPort = 0,
publicAddress = ""
)
- val (actorSystem, port, uiPort, restPort) =
- Master.startSystemAndActor("127.0.0.1", 7077, 8080, conf)
+ val (rpcEnv, uiPort, restPort) =
+ Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf)
try {
- Await.result(actorSystem.actorSelection("/user/Master").resolveOne(10 seconds), 10 seconds)
+ rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME)
CustomPersistenceEngine.lastInstance.isDefined shouldBe true
val persistenceEngine = CustomPersistenceEngine.lastInstance.get
@@ -154,8 +112,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually {
workers.map(_.id) should contain(workerToPersist.id)
} finally {
- actorSystem.shutdown()
- actorSystem.awaitTermination()
+ rpcEnv.shutdown()
+ rpcEnv.awaitTermination()
}
CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
index 197f68e7ec5ed..96e456d889ac3 100644
--- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -23,14 +23,14 @@ import javax.servlet.http.HttpServletResponse
import scala.collection.mutable
-import akka.actor.{Actor, ActorRef, ActorSystem, Props}
import com.google.common.base.Charsets
import org.scalatest.BeforeAndAfterEach
import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods._
import org.apache.spark._
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.rpc._
+import org.apache.spark.util.Utils
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments}
import org.apache.spark.deploy.master.DriverState._
@@ -39,11 +39,11 @@ import org.apache.spark.deploy.master.DriverState._
* Tests for the REST application submission protocol used in standalone cluster mode.
*/
class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
- private var actorSystem: Option[ActorSystem] = None
+ private var rpcEnv: Option[RpcEnv] = None
private var server: Option[RestSubmissionServer] = None
override def afterEach() {
- actorSystem.foreach(_.shutdown())
+ rpcEnv.foreach(_.shutdown())
server.foreach(_.stop())
}
@@ -377,31 +377,32 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
killMessage: String = "driver is killed",
state: DriverState = FINISHED,
exception: Option[Exception] = None): String = {
- startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception))
+ startServer(new DummyMaster(_, submitId, submitMessage, killMessage, state, exception))
}
/** Start a smarter dummy server that keeps track of submitted driver states. */
private def startSmartServer(): String = {
- startServer(new SmarterMaster)
+ startServer(new SmarterMaster(_))
}
/** Start a dummy server that is faulty in many ways... */
private def startFaultyServer(): String = {
- startServer(new DummyMaster, faulty = true)
+ startServer(new DummyMaster(_), faulty = true)
}
/**
- * Start a [[StandaloneRestServer]] that communicates with the given actor.
+ * Start a [[StandaloneRestServer]] that communicates with the given endpoint.
* If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead.
* Return the master URL that corresponds to the address of this server.
*/
- private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = {
+ private def startServer(
+ makeFakeMaster: RpcEnv => RpcEndpoint, faulty: Boolean = false): String = {
val name = "test-standalone-rest-protocol"
val conf = new SparkConf
val localhost = Utils.localHostName()
val securityManager = new SecurityManager(conf)
- val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager)
- val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster))
+ val _rpcEnv = RpcEnv.create(name, localhost, 0, conf, securityManager)
+ val fakeMasterRef = _rpcEnv.setupEndpoint("fake-master", makeFakeMaster(_rpcEnv))
val _server =
if (faulty) {
new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077")
@@ -410,7 +411,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
}
val port = _server.start()
// set these to clean them up after every test
- actorSystem = Some(_actorSystem)
+ rpcEnv = Some(_rpcEnv)
server = Some(_server)
s"spark://$localhost:$port"
}
@@ -505,20 +506,21 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
* In all responses, the success parameter is always true.
*/
private class DummyMaster(
+ override val rpcEnv: RpcEnv,
submitId: String = "fake-driver-id",
submitMessage: String = "submitted",
killMessage: String = "killed",
state: DriverState = FINISHED,
exception: Option[Exception] = None)
- extends Actor {
+ extends RpcEndpoint {
- override def receive: PartialFunction[Any, Unit] = {
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RequestSubmitDriver(driverDesc) =>
- sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage)
+ context.reply(SubmitDriverResponse(self, success = true, Some(submitId), submitMessage))
case RequestKillDriver(driverId) =>
- sender ! KillDriverResponse(driverId, success = true, killMessage)
+ context.reply(KillDriverResponse(self, driverId, success = true, killMessage))
case RequestDriverStatus(driverId) =>
- sender ! DriverStatusResponse(found = true, Some(state), None, None, exception)
+ context.reply(DriverStatusResponse(found = true, Some(state), None, None, exception))
}
}
@@ -531,28 +533,28 @@ private class DummyMaster(
* Submits are always successful while kills and status requests are successful only
* if the driver was submitted in the past.
*/
-private class SmarterMaster extends Actor {
+private class SmarterMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
private var counter: Int = 0
private val submittedDrivers = new mutable.HashMap[String, DriverState]
- override def receive: PartialFunction[Any, Unit] = {
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RequestSubmitDriver(driverDesc) =>
val driverId = s"driver-$counter"
submittedDrivers(driverId) = RUNNING
counter += 1
- sender ! SubmitDriverResponse(success = true, Some(driverId), "submitted")
+ context.reply(SubmitDriverResponse(self, success = true, Some(driverId), "submitted"))
case RequestKillDriver(driverId) =>
val success = submittedDrivers.contains(driverId)
if (success) {
submittedDrivers(driverId) = KILLED
}
- sender ! KillDriverResponse(driverId, success, "killed")
+ context.reply(KillDriverResponse(self, driverId, success, "killed"))
case RequestDriverStatus(driverId) =>
val found = submittedDrivers.contains(driverId)
val state = submittedDrivers.get(driverId)
- sender ! DriverStatusResponse(found, state, None, None, None)
+ context.reply(DriverStatusResponse(found, state, None, None, None))
}
}
@@ -568,7 +570,7 @@ private class FaultyStandaloneRestServer(
host: String,
requestedPort: Int,
masterConf: SparkConf,
- masterActor: ActorRef,
+ masterEndpoint: RpcEndpointRef,
masterUrl: String)
extends RestSubmissionServer(host, requestedPort, masterConf) {
@@ -578,7 +580,7 @@ private class FaultyStandaloneRestServer(
/** A faulty servlet that produces malformed responses. */
class MalformedSubmitServlet
- extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) {
+ extends StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) {
protected override def sendResponse(
responseMessage: SubmitRestProtocolResponse,
responseServlet: HttpServletResponse): Unit = {
@@ -588,7 +590,7 @@ private class FaultyStandaloneRestServer(
}
/** A faulty servlet that produces invalid responses. */
- class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) {
+ class InvalidKillServlet extends StandaloneKillRequestServlet(masterEndpoint, masterConf) {
protected override def handleKill(submissionId: String): KillSubmissionResponse = {
val k = super.handleKill(submissionId)
k.submissionId = null
@@ -597,7 +599,7 @@ private class FaultyStandaloneRestServer(
}
/** A faulty status servlet that explodes. */
- class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) {
+ class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterEndpoint, masterConf) {
private def explode: Int = 1 / 0
protected override def handleStatus(submissionId: String): SubmissionStatusResponse = {
val s = super.handleStatus(submissionId)
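The hunks above migrate the REST submission test doubles from Akka `Actor`s to Spark's internal `RpcEndpoint` API: fake masters are registered on an `RpcEnv` and answer `ask` calls through `RpcCallContext.reply` instead of `sender ! msg`. A minimal sketch of that pattern, assuming it lives under Spark's own sources (the RPC API is `private[spark]`); the endpoint and message names are illustrative, not part of this patch:

```
package org.apache.spark.rpc

import org.apache.spark.{SecurityManager, SparkConf}

// Hypothetical request/response pair, only to show the ask/reply flow.
case class Ping(payload: String)
case class Pong(payload: String)

object RpcEndpointSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // Create an RpcEnv, as startServer() above does for the fake master.
    val env = RpcEnv.create("sketch", "localhost", 0, conf, new SecurityManager(conf))

    // Register an endpoint; receiveAndReply handles messages sent with ask/askWithRetry.
    val ref = env.setupEndpoint("echo", new RpcEndpoint {
      override val rpcEnv: RpcEnv = env
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case Ping(p) => context.reply(Pong(p))
      }
    })

    // Blocks for the reply, mirroring how the REST servlets talk to the master endpoint.
    val answer = ref.askWithRetry[Pong](Ping("hello"))
    assert(answer == Pong("hello"))
    env.shutdown()
  }
}
```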
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala
index 115ac0534a1b4..725b8848bc052 100644
--- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala
@@ -18,11 +18,11 @@
package org.apache.spark.deploy.rest
import java.lang.Boolean
-import java.lang.Integer
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.util.Utils
/**
* Tests for the REST application submission protocol.
@@ -93,7 +93,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite {
// optional fields
conf.set("spark.jars", "mayonnaise.jar,ketchup.jar")
conf.set("spark.files", "fireball.png")
- conf.set("spark.driver.memory", "512m")
+ conf.set("spark.driver.memory", s"${Utils.DEFAULT_DRIVER_MEM_MB}m")
conf.set("spark.driver.cores", "180")
conf.set("spark.driver.extraJavaOptions", " -Dslices=5 -Dcolor=mostly_red")
conf.set("spark.driver.extraClassPath", "food-coloring.jar")
@@ -126,7 +126,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite {
assert(newMessage.sparkProperties("spark.app.name") === "SparkPie")
assert(newMessage.sparkProperties("spark.jars") === "mayonnaise.jar,ketchup.jar")
assert(newMessage.sparkProperties("spark.files") === "fireball.png")
- assert(newMessage.sparkProperties("spark.driver.memory") === "512m")
+ assert(newMessage.sparkProperties("spark.driver.memory") === s"${Utils.DEFAULT_DRIVER_MEM_MB}m")
assert(newMessage.sparkProperties("spark.driver.cores") === "180")
assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") ===
" -Dslices=5 -Dcolor=mostly_red")
@@ -230,7 +230,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite {
""".stripMargin
private val submitDriverRequestJson =
- """
+ s"""
|{
| "action" : "CreateSubmissionRequest",
| "appArgs" : [ "two slices", "a hint of cinnamon" ],
@@ -246,7 +246,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite {
| "spark.driver.supervise" : "false",
| "spark.app.name" : "SparkPie",
| "spark.cores.max" : "10000",
- | "spark.driver.memory" : "512m",
+ | "spark.driver.memory" : "${Utils.DEFAULT_DRIVER_MEM_MB}m",
| "spark.files" : "fireball.png",
| "spark.driver.cores" : "180",
| "spark.driver.extraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red",
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
index ac18f04a11475..cd24d79423316 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.deploy.worker
-import akka.actor.AddressFromURIString
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.SecurityManager
import org.apache.spark.rpc.{RpcAddress, RpcEnv}
@@ -26,13 +25,11 @@ class WorkerWatcherSuite extends SparkFunSuite {
test("WorkerWatcher shuts down on valid disassociation") {
val conf = new SparkConf()
val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
- val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker"
- val targetWorkerAddress = AddressFromURIString(targetWorkerUrl)
+ val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
workerWatcher.setTesting(testing = true)
rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
- workerWatcher.onDisconnected(
- RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get))
+ workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234))
assert(workerWatcher.isShutDown)
rpcEnv.shutdown()
}
@@ -40,13 +37,13 @@ class WorkerWatcherSuite extends SparkFunSuite {
test("WorkerWatcher stays alive on invalid disassociation") {
val conf = new SparkConf()
val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
- val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker"
- val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor"
- val otherAkkaAddress = AddressFromURIString(otherAkkaURL)
+ val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker")
+ val otherAddress = "akka://test@4.3.2.1:1234/user/OtherActor"
+ val otherAkkaAddress = RpcAddress("4.3.2.1", 1234)
val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl)
workerWatcher.setTesting(testing = true)
rpcEnv.setupEndpoint("worker-watcher", workerWatcher)
- workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get))
+ workerWatcher.onDisconnected(otherAkkaAddress)
assert(!workerWatcher.isShutDown)
rpcEnv.shutdown()
}
diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
index 63947df3d43a2..8a199459c1ddf 100644
--- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.hadoop.io.Text
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils
import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec}
@@ -36,7 +36,7 @@ import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, Gzi
* [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary
* directory is created as fake input. Temporal storage would be deleted in the end.
*/
-class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll {
+class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
private var sc: SparkContext = _
private var factory: CompressionCodecFactory = _
@@ -85,7 +85,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl
*/
test("Correctness of WholeTextFileRecordReader.") {
val dir = Utils.createTempDir()
- println(s"Local disk address is ${dir.toString}.")
+ logInfo(s"Local disk address is ${dir.toString}.")
WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) =>
createNativeFile(dir, filename, contents, false)
@@ -109,7 +109,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl
test("Correctness of WholeTextFileRecordReader with GzipCodec.") {
val dir = Utils.createTempDir()
- println(s"Local disk address is ${dir.toString}.")
+ logInfo(s"Local disk address is ${dir.toString}.")
WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) =>
createNativeFile(dir, filename, contents, true)
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 9e4d34fb7d382..d3218a548efc7 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -60,7 +60,9 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt")
val pw = new PrintWriter(new FileWriter(tmpFile))
for (x <- 1 to numRecords) {
+ // scalastyle:off println
pw.println(RandomUtils.nextInt(0, numBuckets))
+ // scalastyle:on println
}
pw.close()
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala
new file mode 100644
index 0000000000000..b3223ec61bf79
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+
+class RpcAddressSuite extends SparkFunSuite {
+
+ test("hostPort") {
+ val address = RpcAddress("1.2.3.4", 1234)
+ assert(address.host == "1.2.3.4")
+ assert(address.port == 1234)
+ assert(address.hostPort == "1.2.3.4:1234")
+ }
+
+ test("fromSparkURL") {
+ val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234")
+ assert(address.host == "1.2.3.4")
+ assert(address.port == 1234)
+ }
+
+ test("fromSparkURL: a typo url") {
+ val e = intercept[SparkException] {
+ RpcAddress.fromSparkURL("spark://1.2. 3.4:1234")
+ }
+ assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
+ }
+
+ test("fromSparkURL: invalid scheme") {
+ val e = intercept[SparkException] {
+ RpcAddress.fromSparkURL("invalid://1.2.3.4:1234")
+ }
+ assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage)
+ }
+
+ test("toSparkURL") {
+ val address = RpcAddress("1.2.3.4", 1234)
+ assert(address.toSparkURL == "spark://1.2.3.4:1234")
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 1f0aa759b08da..6ceafe4337747 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -155,16 +155,21 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
})
val conf = new SparkConf()
+ val shortProp = "spark.rpc.short.timeout"
conf.set("spark.rpc.retry.wait", "0")
conf.set("spark.rpc.numRetries", "1")
val anotherEnv = createRpcEnv(conf, "remote", 13345)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout")
try {
- val e = intercept[Exception] {
- rpcEndpointRef.askWithRetry[String]("hello", 1 millis)
+ // Any exception thrown in askWithRetry is wrapped in a SparkException, with the original as the cause
+ val e = intercept[SparkException] {
+ rpcEndpointRef.askWithRetry[String]("hello", new RpcTimeout(1 millis, shortProp))
}
- assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException])
+ // The SparkException's cause should be an RpcTimeoutException whose message names the
+ // controlling timeout property
+ assert(e.getCause.isInstanceOf[RpcTimeoutException])
+ assert(e.getCause.getMessage.contains(shortProp))
} finally {
anotherEnv.shutdown()
anotherEnv.awaitTermination()
@@ -539,6 +544,92 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
+ test("construct RpcTimeout with conf property") {
+ val conf = new SparkConf
+
+ val testProp = "spark.ask.test.timeout"
+ val testDurationSeconds = 30
+ val secondaryProp = "spark.ask.secondary.timeout"
+
+ conf.set(testProp, s"${testDurationSeconds}s")
+ conf.set(secondaryProp, "100s")
+
+ // Construct RpcTimeout with a single property
+ val rt1 = RpcTimeout(conf, testProp)
+ assert( testDurationSeconds === rt1.duration.toSeconds )
+
+ // Construct RpcTimeout with prioritized list of properties
+ val rt2 = RpcTimeout(conf, Seq("spark.ask.invalid.timeout", testProp, secondaryProp), "1s")
+ assert( testDurationSeconds === rt2.duration.toSeconds )
+
+ // Construct RpcTimeout with a default value
+ val defaultProp = "spark.ask.default.timeout"
+ val defaultDurationSeconds = 1
+ val rt3 = RpcTimeout(conf, Seq(defaultProp), defaultDurationSeconds.toString + "s")
+ assert( defaultDurationSeconds === rt3.duration.toSeconds )
+ assert( rt3.timeoutProp.contains(defaultProp) )
+
+ // Try to construct RpcTimeout with an unconfigured property
+ intercept[NoSuchElementException] {
+ RpcTimeout(conf, "spark.ask.invalid.timeout")
+ }
+ }
+
+ test("ask a message timeout on Future using RpcTimeout") {
+ case class NeverReply(msg: String)
+
+ val rpcEndpointRef = env.setupEndpoint("ask-future", new RpcEndpoint {
+ override val rpcEnv = env
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case msg: String => context.reply(msg)
+ case _: NeverReply =>
+ }
+ })
+
+ val longTimeout = new RpcTimeout(1 second, "spark.rpc.long.timeout")
+ val shortTimeout = new RpcTimeout(10 millis, "spark.rpc.short.timeout")
+
+ // Ask with immediate response, should complete successfully
+ val fut1 = rpcEndpointRef.ask[String]("hello", longTimeout)
+ val reply1 = longTimeout.awaitResult(fut1)
+ assert("hello" === reply1)
+
+ // Ask with a delayed response and wait for the result immediately; the wait should time out
+ val fut2 = rpcEndpointRef.ask[String](NeverReply("doh"), shortTimeout)
+ val reply2 =
+ intercept[RpcTimeoutException] {
+ shortTimeout.awaitResult(fut2)
+ }.getMessage
+
+ // RpcTimeout.awaitResult should have added the property to the TimeoutException message
+ assert(reply2.contains(shortTimeout.timeoutProp))
+
+ // Ask with a delayed response and allow the Future to time out before Await.result
+ val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout)
+
+ // Let the future complete with failure using plain Await.result; it returns once the
+ // future has completed, verifying that addMessageIfTimeout was invoked
+ val reply3 =
+ intercept[RpcTimeoutException] {
+ Await.result(fut3, 200 millis)
+ }.getMessage
+
+ // When the future timed out, the recover callback should have used
+ // RpcTimeout.addMessageIfTimeout to add the property to the TimeoutException message
+ assert(reply3.contains(shortTimeout.timeoutProp))
+
+ // Use RpcTimeout.awaitResult to process the future; since it has already failed with an
+ // RpcTimeoutException, the same RpcTimeoutException should be thrown
+ val reply4 =
+ intercept[RpcTimeoutException] {
+ shortTimeout.awaitResult(fut3)
+ }.getMessage
+
+ // Ensure the description is not added to the message twice by addMessageIfTimeout and awaitResult
+ assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1)
+ }
+
}
class UnserializableClass
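The new tests above exercise the `RpcTimeout` helper: it binds a timeout duration to the configuration property that controls it, so a timeout surfaced through `awaitResult` names the setting a user should raise. A compact sketch of the intended usage, under the same assumptions as the tests (property names are illustrative, and the class is internal to Spark):

```
package org.apache.spark.rpc

import org.apache.spark.SparkConf

object RpcTimeoutSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.set("spark.ask.test.timeout", "30s")

    // Single property: the duration comes straight from the conf.
    val t1 = RpcTimeout(conf, "spark.ask.test.timeout")
    assert(t1.duration.toSeconds == 30)

    // Prioritized property list with a default: the first configured property wins,
    // and the default string is used only when none of them is set.
    val t2 = RpcTimeout(conf, Seq("spark.ask.unset.timeout", "spark.ask.test.timeout"), "120s")
    assert(t2.duration.toSeconds == 30)

    // In real code the timeout wraps an ask, and awaitResult rethrows any
    // TimeoutException as an RpcTimeoutException that names t1.timeoutProp:
    //   val reply = t1.awaitResult(endpointRef.ask[String](message, t1))
  }
}
```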
diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
index a33a83db7bc9e..4aa75c9230b2c 100644
--- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.rpc.akka
import org.apache.spark.rpc._
-import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf}
class AkkaRpcEnvSuite extends RpcEnvSuite {
@@ -47,4 +47,22 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
}
}
+ test("uriOf") {
+ val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint")
+ assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri)
+ }
+
+ test("uriOf: ssl") {
+ val conf = SSLSampleConfigs.sparkSSLConfig()
+ val securityManager = new SecurityManager(conf)
+ val rpcEnv = new AkkaRpcEnvFactory().create(
+ RpcEnvConfig(conf, "test", "localhost", 12346, securityManager))
+ try {
+ val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint")
+ assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri)
+ } finally {
+ rpcEnv.shutdown()
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index ff3fa95ec32ae..4e3defb43a021 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -52,8 +52,10 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter {
val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None,
125L, "Mickey", None)
val applicationEnd = SparkListenerApplicationEnd(1000L)
+ // scalastyle:off println
writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart))))
writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd))))
+ // scalastyle:on println
writer.close()
val conf = EventLoggingListenerSuite.getLoggingConf(logFilePath)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala
new file mode 100644
index 0000000000000..3f1692917a357
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import java.util
+import java.util.Collections
+
+import org.apache.mesos.Protos.Value.Scalar
+import org.apache.mesos.Protos._
+import org.apache.mesos.SchedulerDriver
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.Matchers
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+
+class CoarseMesosSchedulerBackendSuite extends SparkFunSuite
+ with LocalSparkContext
+ with MockitoSugar
+ with BeforeAndAfter {
+
+ private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = {
+ val builder = Offer.newBuilder()
+ builder.addResourcesBuilder()
+ .setName("mem")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Scalar.newBuilder().setValue(mem))
+ builder.addResourcesBuilder()
+ .setName("cpus")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Scalar.newBuilder().setValue(cpu))
+ builder.setId(OfferID.newBuilder()
+ .setValue(offerId).build())
+ .setFrameworkId(FrameworkID.newBuilder()
+ .setValue("f1"))
+ .setSlaveId(SlaveID.newBuilder().setValue(slaveId))
+ .setHostname(s"host${slaveId}")
+ .build()
+ }
+
+ private def createSchedulerBackend(
+ taskScheduler: TaskSchedulerImpl,
+ driver: SchedulerDriver): CoarseMesosSchedulerBackend = {
+ val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") {
+ mesosDriver = driver
+ markRegistered()
+ }
+ backend.start()
+ backend
+ }
+
+ var sparkConf: SparkConf = _
+
+ before {
+ sparkConf = (new SparkConf)
+ .setMaster("local[*]")
+ .setAppName("test-mesos-dynamic-alloc")
+ .setSparkHome("/path")
+
+ sc = new SparkContext(sparkConf)
+ }
+
+ test("mesos supports killing and limiting executors") {
+ val driver = mock[SchedulerDriver]
+ val taskScheduler = mock[TaskSchedulerImpl]
+ when(taskScheduler.sc).thenReturn(sc)
+
+ sparkConf.set("spark.driver.host", "driverHost")
+ sparkConf.set("spark.driver.port", "1234")
+
+ val backend = createSchedulerBackend(taskScheduler, driver)
+ val minMem = backend.calculateTotalMemory(sc).toInt
+ val minCpu = 4
+
+ val mesosOffers = new java.util.ArrayList[Offer]
+ mesosOffers.add(createOffer("o1", "s1", minMem, minCpu))
+
+ val taskID0 = TaskID.newBuilder().setValue("0").build()
+
+ backend.resourceOffers(driver, mesosOffers)
+ verify(driver, times(1)).launchTasks(
+ Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)),
+ any[util.Collection[TaskInfo]],
+ any[Filters])
+
+ // simulate the allocation manager down-scaling executors
+ backend.doRequestTotalExecutors(0)
+ assert(backend.doKillExecutors(Seq("s1/0")))
+ verify(driver, times(1)).killTask(taskID0)
+
+ val mesosOffers2 = new java.util.ArrayList[Offer]
+ mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu))
+ backend.resourceOffers(driver, mesosOffers2)
+
+ verify(driver, times(1))
+ .declineOffer(OfferID.newBuilder().setValue("o2").build())
+
+ // Verify we didn't launch any new executor
+ assert(backend.slaveIdsWithExecutors.size === 1)
+
+ backend.doRequestTotalExecutors(2)
+ backend.resourceOffers(driver, mesosOffers2)
+ verify(driver, times(1)).launchTasks(
+ Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)),
+ any[util.Collection[TaskInfo]],
+ any[Filters])
+
+ assert(backend.slaveIdsWithExecutors.size === 2)
+ backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build())
+ assert(backend.slaveIdsWithExecutors.size === 1)
+ }
+
+ test("mesos supports killing and relaunching tasks with executors") {
+ val driver = mock[SchedulerDriver]
+ val taskScheduler = mock[TaskSchedulerImpl]
+ when(taskScheduler.sc).thenReturn(sc)
+
+ val backend = createSchedulerBackend(taskScheduler, driver)
+ val minMem = backend.calculateTotalMemory(sc).toInt + 1024
+ val minCpu = 4
+
+ val mesosOffers = new java.util.ArrayList[Offer]
+ val offer1 = createOffer("o1", "s1", minMem, minCpu)
+ mesosOffers.add(offer1)
+
+ val offer2 = createOffer("o2", "s1", minMem, 1);
+
+ backend.resourceOffers(driver, mesosOffers)
+
+ verify(driver, times(1)).launchTasks(
+ Matchers.eq(Collections.singleton(offer1.getId)),
+ anyObject(),
+ anyObject[Filters])
+
+ // Simulate task killed, executor no longer running
+ val status = TaskStatus.newBuilder()
+ .setTaskId(TaskID.newBuilder().setValue("0").build())
+ .setSlaveId(SlaveID.newBuilder().setValue("s1").build())
+ .setState(TaskState.TASK_KILLED)
+ .build
+
+ backend.statusUpdate(driver, status)
+ assert(!backend.slaveIdsWithExecutors.contains("s1"))
+
+ mesosOffers.clear()
+ mesosOffers.add(offer2)
+ backend.resourceOffers(driver, mesosOffers)
+ assert(backend.slaveIdsWithExecutors.contains("s1"))
+
+ verify(driver, times(1)).launchTasks(
+ Matchers.eq(Collections.singleton(offer2.getId)),
+ anyObject(),
+ anyObject[Filters])
+
+ verify(driver, times(1)).reviveOffers()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala
deleted file mode 100644
index e72285d03d3ee..0000000000000
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster.mesos
-
-import org.mockito.Mockito._
-import org.scalatest.mock.MockitoSugar
-
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-
-class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar {
- test("MesosMemoryUtils should always override memoryOverhead when it's set") {
- val sparkConf = new SparkConf
-
- val sc = mock[SparkContext]
- when(sc.conf).thenReturn(sparkConf)
-
- // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896
- when(sc.executorMemory).thenReturn(512)
- assert(MemoryUtils.calculateTotalMemory(sc) === 896)
-
- // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6
- when(sc.executorMemory).thenReturn(4096)
- assert(MemoryUtils.calculateTotalMemory(sc) === 4505)
-
- // set memoryOverhead
- sparkConf.set("spark.mesos.executor.memoryOverhead", "100")
- assert(MemoryUtils.calculateTotalMemory(sc) === 4196)
- sparkConf.set("spark.mesos.executor.memoryOverhead", "400")
- assert(MemoryUtils.calculateTotalMemory(sc) === 4496)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
index 68df46a41ddc8..d01837fe78957 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
@@ -149,7 +149,9 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi
when(sc.conf).thenReturn(new SparkConf)
when(sc.listenerBus).thenReturn(listenerBus)
- val minMem = MemoryUtils.calculateTotalMemory(sc).toInt
+ val backend = new MesosSchedulerBackend(taskScheduler, sc, "master")
+
+ val minMem = backend.calculateTotalMemory(sc)
val minCpu = 4
val mesosOffers = new java.util.ArrayList[Offer]
@@ -157,8 +159,6 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi
mesosOffers.add(createOffer(2, minMem - 1, minCpu))
mesosOffers.add(createOffer(3, minMem, minCpu))
- val backend = new MesosSchedulerBackend(taskScheduler, sc, "master")
-
val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](2)
expectedWorkerOffers.append(new WorkerOffer(
mesosOffers.get(0).getSlaveId.getValue,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
new file mode 100644
index 0000000000000..b354914b6ffd0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import org.apache.mesos.Protos.Value
+import org.mockito.Mockito._
+import org.scalatest._
+import org.scalatest.mock.MockitoSugar
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+
+class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar {
+
+ // scalastyle:off structural.type
+ // This is the documented way of generating fixtures in ScalaTest
+ def fixture: Object {val sc: SparkContext; val sparkConf: SparkConf} = new {
+ val sparkConf = new SparkConf
+ val sc = mock[SparkContext]
+ when(sc.conf).thenReturn(sparkConf)
+ }
+ val utils = new MesosSchedulerUtils { }
+ // scalastyle:on structural.type
+
+ test("use at-least minimum overhead") {
+ val f = fixture
+ when(f.sc.executorMemory).thenReturn(512)
+ utils.calculateTotalMemory(f.sc) shouldBe 896
+ }
+
+ test("use overhead if it is greater than minimum value") {
+ val f = fixture
+ when(f.sc.executorMemory).thenReturn(4096)
+ utils.calculateTotalMemory(f.sc) shouldBe 4505
+ }
+
+ test("use spark.mesos.executor.memoryOverhead (if set)") {
+ val f = fixture
+ when(f.sc.executorMemory).thenReturn(1024)
+ f.sparkConf.set("spark.mesos.executor.memoryOverhead", "512")
+ utils.calculateTotalMemory(f.sc) shouldBe 1536
+ }
+
+ test("parse a non-empty constraint string correctly") {
+ val expectedMap = Map(
+ "tachyon" -> Set("true"),
+ "zone" -> Set("us-east-1a", "us-east-1b")
+ )
+ utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") should be (expectedMap)
+ }
+
+ test("parse an empty constraint string correctly") {
+ utils.parseConstraintString("") shouldBe Map()
+ }
+
+ test("throw an exception when the input is malformed") {
+ an[IllegalArgumentException] should be thrownBy
+ utils.parseConstraintString("tachyon;zone:us-east")
+ }
+
+ test("empty values for attributes' constraints matches all values") {
+ val constraintsStr = "tachyon:"
+ val parsedConstraints = utils.parseConstraintString(constraintsStr)
+
+ parsedConstraints shouldBe Map("tachyon" -> Set())
+
+ val zoneSet = Value.Set.newBuilder().addItem("us-east-1a").addItem("us-east-1b").build()
+ val noTachyonOffer = Map("zone" -> zoneSet)
+ val tachyonTrueOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build())
+ val tachyonFalseOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("false").build())
+
+ utils.matchesAttributeRequirements(parsedConstraints, noTachyonOffer) shouldBe false
+ utils.matchesAttributeRequirements(parsedConstraints, tachyonTrueOffer) shouldBe true
+ utils.matchesAttributeRequirements(parsedConstraints, tachyonFalseOffer) shouldBe true
+ }
+
+ test("subset match is performed for set attributes") {
+ val supersetConstraint = Map(
+ "tachyon" -> Value.Text.newBuilder().setValue("true").build(),
+ "zone" -> Value.Set.newBuilder()
+ .addItem("us-east-1a")
+ .addItem("us-east-1b")
+ .addItem("us-east-1c")
+ .build())
+
+ val zoneConstraintStr = "tachyon:;zone:us-east-1a,us-east-1c"
+ val parsedConstraints = utils.parseConstraintString(zoneConstraintStr)
+
+ utils.matchesAttributeRequirements(parsedConstraints, supersetConstraint) shouldBe true
+ }
+
+ test("less than equal match is performed on scalar attributes") {
+ val offerAttribs = Map("gpus" -> Value.Scalar.newBuilder().setValue(3).build())
+
+ val ltConstraint = utils.parseConstraintString("gpus:2")
+ val eqConstraint = utils.parseConstraintString("gpus:3")
+ val gtConstraint = utils.parseConstraintString("gpus:4")
+
+ utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe true
+ utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true
+ utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false
+ }
+
+ test("contains match is performed for range attributes") {
+ val offerAttribs = Map("ports" -> Value.Range.newBuilder().setBegin(7000).setEnd(8000).build())
+ val ltConstraint = utils.parseConstraintString("ports:6000")
+ val eqConstraint = utils.parseConstraintString("ports:7500")
+ val gtConstraint = utils.parseConstraintString("ports:8002")
+ val multiConstraint = utils.parseConstraintString("ports:5000,7500,8300")
+
+ utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe false
+ utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true
+ utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false
+ utils.matchesAttributeRequirements(multiConstraint, offerAttribs) shouldBe true
+ }
+
+ test("equality match is performed for text attributes") {
+ val offerAttribs = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build())
+
+ val trueConstraint = utils.parseConstraintString("tachyon:true")
+ val falseConstraint = utils.parseConstraintString("tachyon:false")
+
+ utils.matchesAttributeRequirements(trueConstraint, offerAttribs) shouldBe true
+ utils.matchesAttributeRequirements(falseConstraint, offerAttribs) shouldBe false
+ }
+
+}
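The three memory tests above pin down the overhead rule behind `calculateTotalMemory`: unless `spark.mesos.executor.memoryOverhead` is set, the executor memory is padded by the larger of 384 MB and 10% of the executor memory. A standalone re-derivation of the expected values (not Spark's implementation, just the arithmetic the suite asserts):

```
object MesosMemorySketch {
  val MinOverheadMb = 384

  // Explicit overhead wins; otherwise add max(384, 10% of executor memory).
  def totalMemoryMb(executorMemoryMb: Int, overheadConfMb: Option[Int] = None): Int =
    executorMemoryMb + overheadConfMb.getOrElse(math.max(MinOverheadMb, executorMemoryMb / 10))

  def main(args: Array[String]): Unit = {
    assert(totalMemoryMb(512) == 896)               // 512 + 384: minimum overhead applies
    assert(totalMemoryMb(4096) == 4505)             // 4096 + 409: 10% exceeds 384
    assert(totalMemoryMb(1024, Some(512)) == 1536)  // spark.mesos.executor.memoryOverhead=512
  }
}
```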
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 1053c6caf7718..480722a5ac182 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -375,6 +375,7 @@ class TestCreateNullValue {
// parameters of the closure constructor. This allows us to test whether
// null values are created correctly for each type.
val nestedClosure = () => {
+ // scalastyle:off println
if (s.toString == "123") { // Don't really output them to avoid noisy
println(bo)
println(c)
@@ -389,6 +390,7 @@ class TestCreateNullValue {
val closure = () => {
println(getX)
}
+ // scalastyle:on println
ClosureCleaner.clean(closure)
}
nestedClosure()
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index a61ea3918f46a..c7638507c88c6 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -486,11 +486,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
// Test for using the util function to change our log levels.
test("log4j log level change") {
- Utils.setLogLevel(org.apache.log4j.Level.ALL)
- assert(log.isInfoEnabled())
- Utils.setLogLevel(org.apache.log4j.Level.ERROR)
- assert(!log.isInfoEnabled())
- assert(log.isErrorEnabled())
+ val current = org.apache.log4j.Logger.getRootLogger().getLevel()
+ try {
+ Utils.setLogLevel(org.apache.log4j.Level.ALL)
+ assert(log.isInfoEnabled())
+ Utils.setLogLevel(org.apache.log4j.Level.ERROR)
+ assert(!log.isInfoEnabled())
+ assert(log.isErrorEnabled())
+ } finally {
+ // Best effort at undoing changes this test made.
+ Utils.setLogLevel(current)
+ }
}
test("deleteRecursively") {
@@ -673,4 +679,14 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
assert(!Utils.isInDirectory(nullFile, parentDir))
assert(!Utils.isInDirectory(nullFile, childFile3))
}
+
+ test("circular buffer") {
+ val buffer = new CircularBuffer(25)
+ val stream = new java.io.PrintStream(buffer, true, "UTF-8")
+
+ // scalastyle:off println
+ stream.println("test circular test circular test circular test circular test circular")
+ // scalastyle:on println
+ assert(buffer.toString === "t circular test circular\n")
+ }
}
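The new `circular buffer` test depends on `CircularBuffer` keeping only the most recent bytes written through the stream. A toy ring-buffer `OutputStream` with the same observable behaviour (illustrative only; Spark's `CircularBuffer` in `Utils.scala` differs in detail):

```
import java.io.{OutputStream, PrintStream}
import java.nio.charset.StandardCharsets

// Keeps only the last `capacity` bytes written.
class ToyCircularBuffer(capacity: Int = 25) extends OutputStream {
  private val buf = new Array[Byte](capacity)
  private var pos = 0
  private var filled = false

  override def write(b: Int): Unit = {
    buf(pos) = b.toByte
    pos = (pos + 1) % capacity
    if (pos == 0) filled = true
  }

  // Oldest byte first: the tail from `pos` (once wrapped) followed by the head up to `pos`.
  override def toString: String = {
    val bytes = if (filled) buf.slice(pos, capacity) ++ buf.slice(0, pos) else buf.slice(0, pos)
    new String(bytes, StandardCharsets.UTF_8)
  }
}

object ToyCircularBufferDemo {
  def main(args: Array[String]): Unit = {
    val buffer = new ToyCircularBuffer(25)
    val stream = new PrintStream(buffer, true, "UTF-8")
    stream.println("test circular test circular test circular test circular test circular")
    assert(buffer.toString == "t circular test circular\n")  // same expectation as the test above
  }
}
```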
diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
index 5a5919fca2469..4f382414a8dd7 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
@@ -103,7 +103,9 @@ private object SizeTrackerSuite {
*/
def main(args: Array[String]): Unit = {
if (args.size < 1) {
+ // scalastyle:off println
println("Usage: SizeTrackerSuite [num elements]")
+ // scalastyle:on println
System.exit(1)
}
val numElements = args(0).toInt
@@ -180,11 +182,13 @@ private object SizeTrackerSuite {
baseTimes: Seq[Long],
sampledTimes: Seq[Long],
unsampledTimes: Seq[Long]): Unit = {
+ // scalastyle:off println
println(s"Average times for $testName (ms):")
println(" Base - " + averageTime(baseTimes))
println(" SizeTracker (sampled) - " + averageTime(sampledTimes))
println(" SizeEstimator (unsampled) - " + averageTime(unsampledTimes))
println()
+ // scalastyle:on println
}
def time(f: => Unit): Long = {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
index b2f5d9009ee5d..fefa5165db197 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
@@ -20,10 +20,10 @@ package org.apache.spark.util.collection
import java.lang.{Float => JFloat, Integer => JInteger}
import java.util.{Arrays, Comparator}
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.util.random.XORShiftRandom
-class SorterSuite extends SparkFunSuite {
+class SorterSuite extends SparkFunSuite with Logging {
test("equivalent to Arrays.sort") {
val rand = new XORShiftRandom(123)
@@ -74,7 +74,7 @@ class SorterSuite extends SparkFunSuite {
/** Runs an experiment several times. */
def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = {
if (skip) {
- println(s"Skipped experiment $name.")
+ logInfo(s"Skipped experiment $name.")
return
}
@@ -86,11 +86,11 @@ class SorterSuite extends SparkFunSuite {
while (i < 10) {
val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare))
next10 += time
- println(s"$name: Took $time ms")
+ logInfo(s"$name: Took $time ms")
i += 1
}
- println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)")
+ logInfo(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)")
}
/**
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
new file mode 100644
index 0000000000000..dd505dfa7d758
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort
+
+import org.scalatest.prop.PropertyChecks
+
+import org.apache.spark.SparkFunSuite
+
+class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
+
+ test("String prefix comparator") {
+
+ def testPrefixComparison(s1: String, s2: String): Unit = {
+ val s1Prefix = PrefixComparators.STRING.computePrefix(s1)
+ val s2Prefix = PrefixComparators.STRING.computePrefix(s2)
+ val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)
+ assert(
+ (prefixComparisonResult == 0) ||
+ (prefixComparisonResult < 0 && s1 < s2) ||
+ (prefixComparisonResult > 0 && s1 > s2))
+ }
+
+ // scalastyle:off
+ val regressionTests = Table(
+ ("s1", "s2"),
+ ("abc", "世界"),
+ ("你好", "世界"),
+ ("你好123", "你好122")
+ )
+ // scalastyle:on
+
+ forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ }
+}
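The property checked above is that a prefix comparison may declare a tie but must never contradict the full string ordering. A simplified illustration of why packed prefixes satisfy that property (a toy, not `PrefixComparators.STRING`): pack the first three UTF-16 code units into the low 48 bits of a `Long`, most significant first, and compare the longs.

```
object StringPrefixSketch {
  // Shorter strings are padded with 0, which never sorts above a real code unit.
  def toyPrefix(s: String): Long = {
    var prefix = 0L
    var i = 0
    while (i < 3) {
      val c = if (i < s.length) s.charAt(i).toLong else 0L
      prefix = (prefix << 16) | c
      i += 1
    }
    prefix
  }

  def main(args: Array[String]): Unit = {
    val samples = Seq("abc" -> "世界", "你好" -> "世界", "你好123" -> "你好122", "ab" -> "abc")
    for ((s1, s2) <- samples) {
      val cmp = java.lang.Long.compare(toyPrefix(s1), toyPrefix(s2))
      // The prefix may tie, but whenever it decides it must agree with the full comparison.
      assert(cmp == 0 || (cmp < 0) == (s1.compareTo(s2) < 0))
    }
  }
}
```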
diff --git a/data/mllib/pic_data.txt b/data/mllib/pic_data.txt
new file mode 100644
index 0000000000000..fcfef8cd19131
--- /dev/null
+++ b/data/mllib/pic_data.txt
@@ -0,0 +1,19 @@
+0 1 1.0
+0 2 1.0
+0 3 1.0
+1 2 1.0
+1 3 1.0
+2 3 1.0
+3 4 0.1
+4 5 1.0
+4 15 1.0
+5 6 1.0
+6 7 1.0
+7 8 1.0
+8 9 1.0
+9 10 1.0
+10 11 1.0
+11 12 1.0
+12 13 1.0
+13 14 1.0
+14 15 1.0
diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala
index fc03fec9866a6..61d91c70e9709 100644
--- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala
+++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+// scalastyle:off println
package main.scala
import scala.util.Try
@@ -59,3 +60,4 @@ object SimpleApp {
}
}
}
+// scalastyle:on println
diff --git a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala
index 0be8e64fbfabd..9f7ae75d0b477 100644
--- a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala
+++ b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+// scalastyle:off println
package main.scala
import scala.util.Try
@@ -37,3 +38,4 @@ object SimpleApp {
}
}
}
+// scalastyle:on println
diff --git a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala
index 24c7f8d667296..2f0b6ef9a5672 100644
--- a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala
+++ b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+// scalastyle:off println
package main.scala
import org.apache.spark.{SparkContext, SparkConf}
@@ -51,3 +52,4 @@ object GraphXApp {
println("Test succeeded")
}
}
+// scalastyle:on println
diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala
index 5111bc0adb772..4a980ec071ae4 100644
--- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala
+++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+// scalastyle:off println
package main.scala
import scala.collection.mutable.{ListBuffer, Queue}
@@ -55,3 +56,4 @@ object SparkSqlExample {
sc.stop()
}
}
+// scalastyle:on println
diff --git a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala
index 9f85066501472..adc25b57d6aa5 100644
--- a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala
+++ b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+// scalastyle:off println
package main.scala
import scala.util.Try
@@ -31,3 +32,4 @@ object SimpleApp {
}
}
}
+// scalastyle:on println
diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala
index cc86ef45858c9..69c1154dc0955 100644
--- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala
+++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+// scalastyle:off println
package main.scala
import scala.collection.mutable.{ListBuffer, Queue}
@@ -57,3 +58,4 @@ object SparkSqlExample {
sc.stop()
}
}
+// scalastyle:on println
diff --git a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala
index 58a662bd9b2e8..d6a074687f4a1 100644
--- a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala
+++ b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+// scalastyle:off println
package main.scala
import scala.collection.mutable.{ListBuffer, Queue}
@@ -61,3 +62,4 @@ object SparkStreamingExample {
ssc.stop()
}
}
+// scalastyle:on println
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 54274a83f6d66..30190dcd41ec5 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -118,13 +118,13 @@ if [[ ! "$@" =~ --skip-publish ]]; then
rm -rf $SPARK_REPO
- build/mvn -DskipTests -Pyarn -Phive \
+ build/mvn -DskipTests -Pyarn -Phive -Prelease \
-Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
clean install
./dev/change-version-to-2.11.sh
- build/mvn -DskipTests -Pyarn -Phive \
+ build/mvn -DskipTests -Pyarn -Phive -Prelease \
-Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
clean install
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index cf827ce89b857..4a17d48d8171d 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -47,6 +47,12 @@
JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "")
# ASF JIRA password
JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "")
+# OAuth key used for issuing requests against the GitHub API. If this is not defined, then requests
+# will be unauthenticated. You should only need to configure this if you find yourself regularly
+# exceeding your IP's unauthenticated request rate limit. You can create an OAuth key at
+# https://github.com/settings/tokens. This script only requires the "public_repo" scope.
+GITHUB_OAUTH_KEY = os.environ.get("GITHUB_OAUTH_KEY")
+
GITHUB_BASE = "https://github.com/apache/spark/pull"
GITHUB_API_BASE = "https://api.github.com/repos/apache/spark"
@@ -58,9 +64,17 @@
def get_json(url):
try:
- return json.load(urllib2.urlopen(url))
+ request = urllib2.Request(url)
+ if GITHUB_OAUTH_KEY:
+ request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY)
+ return json.load(urllib2.urlopen(request))
except urllib2.HTTPError as e:
- print "Unable to fetch URL, exiting: %s" % url
+ if "X-RateLimit-Remaining" in e.headers and e.headers["X-RateLimit-Remaining"] == '0':
+ print "Exceeded the GitHub API rate limit; see the instructions in " + \
+ "dev/merge_spark_pr.py to configure an OAuth token for making authenticated " + \
+ "GitHub requests."
+ else:
+ print "Unable to fetch URL, exiting: %s" % url
sys.exit(-1)
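The `merge_spark_pr.py` change above attaches an `Authorization: token ...` header whenever `GITHUB_OAUTH_KEY` is set, so the script's GitHub API calls count against the much higher authenticated rate limit, and it reads `X-RateLimit-Remaining` to explain failures. The same request pattern as a small Scala sketch (the URL and environment handling are illustrative; the script itself stays in Python):

```
import java.net.{HttpURLConnection, URL}
import scala.io.Source

object GitHubApiSketch {
  def main(args: Array[String]): Unit = {
    val conn = new URL("https://api.github.com/repos/apache/spark")
      .openConnection().asInstanceOf[HttpURLConnection]
    // Send the token header only when the key is configured, mirroring the Python change.
    sys.env.get("GITHUB_OAUTH_KEY").foreach { token =>
      conn.setRequestProperty("Authorization", s"token $token")
    }
    val remaining = conn.getHeaderField("X-RateLimit-Remaining")
    val body = Source.fromInputStream(conn.getInputStream, "UTF-8").mkString
    println(s"rate limit remaining: $remaining, response length: ${body.length}")
  }
}
```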
diff --git a/dev/run-tests b/dev/run-tests
index a00d9f0c27639..257d1e8d50bb4 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -20,4 +20,4 @@
FWDIR="$(cd "`dirname $0`"/..; pwd)"
cd "$FWDIR"
-exec python -u ./dev/run-tests.py
+exec python -u ./dev/run-tests.py "$@"
diff --git a/dev/run-tests.py b/dev/run-tests.py
index e5c897b94d167..1f0d218514f92 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import itertools
+from optparse import OptionParser
import os
import re
import sys
@@ -95,8 +96,8 @@ def determine_modules_to_test(changed_modules):
['examples', 'graphx']
>>> x = sorted(x.name for x in determine_modules_to_test([modules.sql]))
>>> x # doctest: +NORMALIZE_WHITESPACE
- ['examples', 'hive-thriftserver', 'mllib', 'pyspark-core', 'pyspark-ml', \
- 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming', 'sparkr', 'sql']
+ ['examples', 'hive-thriftserver', 'mllib', 'pyspark-ml', \
+ 'pyspark-mllib', 'pyspark-sql', 'sparkr', 'sql']
"""
# If we're going to have to run all of the tests, then we can just short-circuit
# and return 'root'. No module depends on root, so if it appears then it will be
@@ -292,7 +293,8 @@ def build_spark_sbt(hadoop_version):
build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags
sbt_goals = ["package",
"assembly/assembly",
- "streaming-kafka-assembly/assembly"]
+ "streaming-kafka-assembly/assembly",
+ "streaming-flume-assembly/assembly"]
profiles_and_goals = build_profiles + sbt_goals
print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ",
@@ -360,12 +362,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules):
run_scala_tests_sbt(test_modules, test_profiles)
-def run_python_tests(test_modules):
+def run_python_tests(test_modules, parallelism):
set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS")
command = [os.path.join(SPARK_HOME, "python", "run-tests")]
if test_modules != [modules.root]:
command.append("--modules=%s" % ','.join(m.name for m in test_modules))
+ command.append("--parallelism=%i" % parallelism)
run_cmd(command)
@@ -379,7 +382,25 @@ def run_sparkr_tests():
print("Ignoring SparkR tests as R was not found in PATH")
+def parse_opts():
+ parser = OptionParser(
+ prog="run-tests"
+ )
+ parser.add_option(
+ "-p", "--parallelism", type="int", default=4,
+ help="The number of suites to test in parallel (default %default)"
+ )
+
+ (opts, args) = parser.parse_args()
+ if args:
+ parser.error("Unsupported arguments: %s" % ' '.join(args))
+ if opts.parallelism < 1:
+ parser.error("Parallelism cannot be less than 1")
+ return opts
+
+
def main():
+ opts = parse_opts()
# Ensure the user home directory (HOME) is valid and is an absolute directory
if not USER_HOME or not os.path.isabs(USER_HOME):
print("[error] Cannot determine your home directory as an absolute path;",
@@ -461,7 +482,7 @@ def main():
modules_with_python_tests = [m for m in test_modules if m.python_test_goals]
if modules_with_python_tests:
- run_python_tests(modules_with_python_tests)
+ run_python_tests(modules_with_python_tests, opts.parallelism)
if any(m.should_run_r_tests for m in test_modules):
run_sparkr_tests()
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index efe3a897e9c10..993583e2f4119 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -203,7 +203,7 @@ def contains_file(self, filename):
streaming_flume = Module(
- name="streaming_flume",
+ name="streaming-flume",
dependencies=[streaming],
source_file_regexes=[
"external/flume",
@@ -214,6 +214,15 @@ def contains_file(self, filename):
)
+streaming_flume_assembly = Module(
+ name="streaming-flume-assembly",
+ dependencies=[streaming_flume, streaming_flume_sink],
+ source_file_regexes=[
+ "external/flume-assembly",
+ ]
+)
+
+
mllib = Module(
name="mllib",
dependencies=[streaming, sql],
@@ -241,7 +250,7 @@ def contains_file(self, filename):
pyspark_core = Module(
name="pyspark-core",
- dependencies=[mllib, streaming, streaming_kafka],
+ dependencies=[],
source_file_regexes=[
"python/(?!pyspark/(ml|mllib|sql|streaming))"
],
@@ -281,7 +290,7 @@ def contains_file(self, filename):
pyspark_streaming = Module(
name="pyspark-streaming",
- dependencies=[pyspark_core, streaming, streaming_kafka],
+ dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly],
source_file_regexes=[
"python/pyspark/streaming"
],
diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py
index ad9b0cc89e4ab..12bd0bf3a4fe9 100644
--- a/dev/sparktestsupport/shellutils.py
+++ b/dev/sparktestsupport/shellutils.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+from __future__ import print_function
import os
import shutil
import subprocess
diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile
index 5956d59130fbf..5dbdb8b22a44f 100644
--- a/docker/spark-test/base/Dockerfile
+++ b/docker/spark-test/base/Dockerfile
@@ -17,13 +17,13 @@
FROM ubuntu:precise
-RUN echo "deb http://archive.ubuntu.com/ubuntu precise main universe" > /etc/apt/sources.list
-
# Upgrade package index
-RUN apt-get update
-
# install a few other useful packages plus Open Jdk 7
-RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server
+# Remove unneeded /var/lib/apt/lists/* after install to reduce the
+# docker image size (by ~30MB)
+RUN apt-get update && \
+ apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \
+ rm -rf /var/lib/apt/lists/*
ENV SCALA_VERSION 2.10.4
ENV CDH_VERSION cdh4
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index 6073b3626c45b..15ceda11a8a80 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -63,6 +63,51 @@
puts "cp -r " + source + "/. " + dest
cp_r(source + "/.", dest)
+
+ # Begin updating JavaDoc files for badge post-processing
+ puts "Updating JavaDoc files for badge post-processing"
+ js_script_start = '<script defer type="text/javascript" src="'
+ js_script_end = '.js"></script>'
+
+ javadoc_files = Dir["./" + dest + "/**/*.html"]
+ javadoc_files.each do |javadoc_file|
+ # Determine file depths to reference js files
+ slash_count = javadoc_file.count "/"
+ i = 3
+ path_to_js_file = ""
+ while (i < slash_count) do
+ path_to_js_file = path_to_js_file + "../"
+ i += 1
+ end
+
+ # Create script elements to reference js files
+ javadoc_jquery_script = js_script_start + path_to_js_file + "lib/jquery" + js_script_end;
+ javadoc_api_docs_script = js_script_start + path_to_js_file + "lib/api-javadocs" + js_script_end;
+ javadoc_script_elements = javadoc_jquery_script + javadoc_api_docs_script
+
+ # Add script elements to JavaDoc files
+ javadoc_file_content = File.open(javadoc_file, "r") { |f| f.read }
+ javadoc_file_content = javadoc_file_content.sub("