diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore
index f12f8c275a989..18b2db69db8f1 100644
--- a/R/pkg/.Rbuildignore
+++ b/R/pkg/.Rbuildignore
@@ -6,3 +6,4 @@
 ^README\.Rmd$
 ^src-native$
 ^html$
+^tests/fulltests/*
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index f5c3a749fe0a1..e3528bc7c3135 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -334,7 +334,7 @@ setMethod("toDF", signature(x = "RDD"),
 #'
 #' Loads a JSON file, returning the result as a SparkDataFrame
 #' By default, (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON}
-#' ) is supported. For JSON (one record per file), set a named property \code{wholeFile} to
+#' ) is supported. For JSON (one record per file), set a named property \code{multiLine} to
 #' \code{TRUE}.
 #' It goes through the entire dataset once to determine the schema.
 #'
@@ -348,7 +348,7 @@ setMethod("toDF", signature(x = "RDD"),
 #' sparkR.session()
 #' path <- "path/to/file.json"
 #' df <- read.json(path)
-#' df <- read.json(path, wholeFile = TRUE)
+#' df <- read.json(path, multiLine = TRUE)
 #' df <- jsonFile(path)
 #' }
 #' @name read.json
@@ -598,7 +598,7 @@ tableToDF <- function(tableName) {
 #' df1 <- read.df("path/to/file.json", source = "json")
 #' schema <- structType(structField("name", "string"),
 #'                      structField("info", "map<string,double>"))
-#' df2 <- read.df(mapTypeJsonPath, "json", schema, wholeFile = TRUE)
+#' df2 <- read.df(mapTypeJsonPath, "json", schema, multiLine = TRUE)
 #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true")
 #' }
 #' @name read.df
diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R
index 4ca7aa664e023..ec931befa2854 100644
--- a/R/pkg/R/install.R
+++ b/R/pkg/R/install.R
@@ -267,7 +267,7 @@ hadoopVersionName <- function(hadoopVersion) {
 # The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and
 # adapt to Spark context
 sparkCachePath <- function() {
-  if (.Platform$OS.type == "windows") {
+  if (is_windows()) {
     winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA)
     if (is.na(winAppPath)) {
       stop(paste("%LOCALAPPDATA% not found.",
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index ea45e394500e8..91483a4d23d9b 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -908,10 +908,6 @@ isAtomicLengthOne <- function(x) {
   is.atomic(x) && length(x) == 1
 }
 
-is_cran <- function() {
-  !identical(Sys.getenv("NOT_CRAN"), "true")
-}
-
 is_windows <- function() {
   .Platform$OS.type == "windows"
 }
@@ -920,6 +916,6 @@ hadoop_home_set <- function() {
   !identical(Sys.getenv("HADOOP_HOME"), "")
 }
 
-not_cran_or_windows_with_hadoop <- function() {
-  !is_cran() && (!is_windows() || hadoop_home_set())
+windows_with_hadoop <- function() {
+  !is_windows() || hadoop_home_set()
 }
diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R
new file mode 100644
index 0000000000000..de47162d5325f
--- /dev/null
+++ b/R/pkg/inst/tests/testthat/test_basic.R
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("basic tests for CRAN")
+
+test_that("create DataFrame from list or data.frame", {
+  sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
+
+  i <- 4
+  df <- createDataFrame(data.frame(dummy = 1:i))
+  expect_equal(count(df), i)
+
+  l <- list(list(a = 1, b = 2), list(a = 3, b = 4))
+  df <- createDataFrame(l)
+  expect_equal(columns(df), c("a", "b"))
+
+  a <- 1:3
+  b <- c("a", "b", "c")
+  ldf <- data.frame(a, b)
+  df <- createDataFrame(ldf)
+  expect_equal(columns(df), c("a", "b"))
+  expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+  expect_equal(count(df), 3)
+  ldf2 <- collect(df)
+  expect_equal(ldf$a, ldf2$a)
+
+  mtcarsdf <- createDataFrame(mtcars)
+  expect_equivalent(collect(mtcarsdf), mtcars)
+
+  bytes <- as.raw(c(1, 2, 3))
+  df <- createDataFrame(list(list(bytes)))
+  expect_equal(collect(df)[[1]][[1]], bytes)
+
+  sparkR.session.stop()
+})
+
+test_that("spark.glm and predict", {
+  sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
+
+  training <- suppressWarnings(createDataFrame(iris))
+  # gaussian family
+  model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species)
+  prediction <- predict(model, training)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+  vals <- collect(select(prediction, "prediction"))
+  rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+  # Gamma family
+  x <- runif(100, -1, 1)
+  y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
+  df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
+  model <- glm(y ~ x, family = Gamma, df)
+  out <- capture.output(print(summary(model)))
+  expect_true(any(grepl("Dispersion parameter for gamma family", out)))
+
+  # tweedie family
+  model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
+                     family = "tweedie", var.power = 1.2, link.power = 0.0)
+  prediction <- predict(model, training)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+  vals <- collect(select(prediction, "prediction"))
+
+  # manual calculation of the R predicted values to avoid dependence on statmod
+  #' library(statmod)
+  #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris,
+  #'               family = tweedie(var.power = 1.2, link.power = 0.0))
+  #' print(coef(rModel))
+
+  rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174)
+  rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species,
+                                       data = iris) %*% rCoef))
+  expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals)
+
+  sparkR.session.stop()
+})
diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/tests/fulltests/jarTest.R
similarity index 100%
rename from R/pkg/inst/tests/testthat/jarTest.R
rename to R/pkg/tests/fulltests/jarTest.R
diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/tests/fulltests/packageInAJarTest.R
similarity index 100%
rename from R/pkg/inst/tests/testthat/packageInAJarTest.R
rename to R/pkg/tests/fulltests/packageInAJarTest.R
diff --git a/R/pkg/inst/tests/testthat/test_Serde.R
b/R/pkg/tests/fulltests/test_Serde.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_Serde.R rename to R/pkg/tests/fulltests/test_Serde.R index 6e160fae1afed..6bbd201bf1d82 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/tests/fulltests/test_Serde.R @@ -20,8 +20,6 @@ context("SerDe functionality") sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("SerDe of primitive types", { - skip_on_cran() - x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") @@ -40,8 +38,6 @@ test_that("SerDe of primitive types", { }) test_that("SerDe of list of primitive types", { - skip_on_cran() - x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) @@ -69,8 +65,6 @@ test_that("SerDe of list of primitive types", { }) test_that("SerDe of list of lists", { - skip_on_cran() - x <- list(list(1L, 2L, 3L), list(1, 2, 3), list(TRUE, FALSE), list("a", "b", "c")) y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R similarity index 85% rename from R/pkg/inst/tests/testthat/test_Windows.R rename to R/pkg/tests/fulltests/test_Windows.R index 00d684e1a49ef..b2ec6c67311db 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/tests/fulltests/test_Windows.R @@ -17,9 +17,7 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { - skip_on_cran() - - if (.Platform$OS.type != "windows") { + if (!is_windows()) { skip("This test is only for Windows, skipped") } @@ -27,6 +25,3 @@ test_that("sparkJars tag in SparkContext", { abcPath <- testOutput[1] expect_equal(abcPath, "a\\b\\c") }) - -message("--- End test (Windows) ", as.POSIXct(Sys.time(), tz = "GMT")) -message("elapsed ", (proc.time() - timer_ptm)[3]) diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/tests/fulltests/test_binaryFile.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_binaryFile.R rename to R/pkg/tests/fulltests/test_binaryFile.R index 00954fa31b0ee..758b174b8787c 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/tests/fulltests/test_binaryFile.R @@ -24,8 +24,6 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -40,8 +38,6 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) @@ -54,8 +50,6 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -80,8 +74,6 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern 
= "spark-test", fileext = ".tmp") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_binary_function.R rename to R/pkg/tests/fulltests/test_binary_function.R index 236cb3885445e..442bed509bb1d 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/tests/fulltests/test_binary_function.R @@ -29,8 +29,6 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { - skip_on_cran() - actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) @@ -53,8 +51,6 @@ test_that("union on two RDDs", { }) test_that("cogroup on two RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) @@ -73,8 +69,6 @@ test_that("cogroup on two RDDs", { }) test_that("zipPartitions() on RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/tests/fulltests/test_broadcast.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_broadcast.R rename to R/pkg/tests/fulltests/test_broadcast.R index 2c96740df77bb..fc2c7c2deb825 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/tests/fulltests/test_broadcast.R @@ -26,8 +26,6 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { - skip_on_cran() - randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcastRDD(sc, randomMat) @@ -40,8 +38,6 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { - skip_on_cran() - randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/tests/fulltests/test_client.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_client.R rename to R/pkg/tests/fulltests/test_client.R index 3d53bebab6300..0cf25fe1dbf39 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/tests/fulltests/test_client.R @@ -18,8 +18,6 @@ context("functions in client.R") test_that("adding spark-testing-base as a package works", { - skip_on_cran() - args <- generateSparkSubmitArgs("", "", "", "", "holdenk:spark-testing-base:1.3.0_0.0.5") expect_equal(gsub("[[:space:]]", "", args), @@ -28,22 +26,16 @@ test_that("adding spark-testing-base as a package works", { }) test_that("no package specified doesn't add packages flag", { - skip_on_cran() - args <- generateSparkSubmitArgs("", "", "", "", "") expect_equal(gsub("[[:space:]]", "", args), "") }) test_that("multiple packages don't produce a warning", { - skip_on_cran() - expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { - skip_on_cran() - args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/tests/fulltests/test_context.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_context.R rename to R/pkg/tests/fulltests/test_context.R index f6d9f5423df02..710485d56685a 100644 --- 
a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -18,8 +18,6 @@ context("test functions in sparkR.R") test_that("Check masked functions", { - skip_on_cran() - # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. @@ -57,8 +55,6 @@ test_that("Check masked functions", { }) test_that("repeatedly starting and stopping SparkR", { - skip_on_cran() - for (i in 1:4) { sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) rdd <- parallelize(sc, 1:20, 2L) @@ -77,8 +73,6 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { - skip_on_cran() - sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 @@ -102,8 +96,6 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { - skip_on_cran() - sc <- sparkR.sparkContext(master = sparkRTestMaster) setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") @@ -116,16 +108,12 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { - skip_on_cran() - sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { - skip_on_cran() - e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) @@ -153,8 +141,6 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli }) test_that("sparkJars sparkPackages as comma-separated strings", { - skip_on_cran() - expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) expect_equal(lapply(jars, basename), list("a", "b")) @@ -182,8 +168,6 @@ test_that("spark.lapply should perform simple transforms", { }) test_that("add and get file to be downloaded with Spark job on every node", { - skip_on_cran() - sparkR.sparkContext(master = sparkRTestMaster) # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/tests/fulltests/test_includePackage.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_includePackage.R rename to R/pkg/tests/fulltests/test_includePackage.R index d7d9eeed1575e..f4ea0d1b5cb27 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/tests/fulltests/test_includePackage.R @@ -26,8 +26,6 @@ nums <- 1:2 rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { - skip_on_cran() - # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) @@ -44,8 +42,6 @@ test_that("include inside function", { }) test_that("use include package", { - skip_on_cran() - # Only run the test if plyr is installed. 
if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/tests/fulltests/test_jvm_api.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_jvm_api.R rename to R/pkg/tests/fulltests/test_jvm_api.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_mllib_classification.R rename to R/pkg/tests/fulltests/test_mllib_classification.R index 82e588dc460d0..726e9d9a20b1c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -28,8 +28,6 @@ absoluteSparkPath <- function(x) { } test_that("spark.svmLinear", { - skip_on_cran() - df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10) @@ -51,7 +49,7 @@ test_that("spark.svmLinear", { expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) # Test model save and load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -131,7 +129,7 @@ test_that("spark.logit", { expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) # Test model save and load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -228,8 +226,6 @@ test_that("spark.logit", { }) test_that("spark.mlp", { - skip_on_cran() - df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), @@ -250,7 +246,7 @@ test_that("spark.mlp", { expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -363,7 +359,7 @@ test_that("spark.naiveBayes", { "Yes", "Yes", "No", "No")) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") write.ml(m, modelPath) expect_error(write.ml(m, modelPath)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/tests/fulltests/test_mllib_clustering.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_mllib_clustering.R rename to R/pkg/tests/fulltests/test_mllib_clustering.R index e827e961ab4c1..4110e13da4948 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/tests/fulltests/test_mllib_clustering.R @@ -28,8 +28,6 @@ absoluteSparkPath <- function(x) { } test_that("spark.bisectingKmeans", { - skip_on_cran() - newIris <- iris newIris$Species <- NULL training <- suppressWarnings(createDataFrame(newIris)) @@ -55,7 +53,7 @@ test_that("spark.bisectingKmeans", { c(0, 1, 2, 3)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") 
write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -129,7 +127,7 @@ test_that("spark.gaussianMixture", { expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -177,7 +175,7 @@ test_that("spark.kmeans", { expect_true(class(summary.model$coefficients[1, ]) == "numeric") # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -244,7 +242,7 @@ test_that("spark.lda with libsvm", { expect_true(logPrior <= 0 & !is.na(logPrior)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -265,8 +263,6 @@ test_that("spark.lda with libsvm", { }) test_that("spark.lda with text input", { - skip_on_cran() - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, optimizer = "online", features = "value") @@ -309,8 +305,6 @@ test_that("spark.lda with text input", { }) test_that("spark.posterior and spark.perplexity", { - skip_on_cran() - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, features = "value", k = 3) diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_mllib_fpm.R rename to R/pkg/tests/fulltests/test_mllib_fpm.R index 4e10ca1e4f50b..69dda52f0c279 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_fpm.R +++ b/R/pkg/tests/fulltests/test_mllib_fpm.R @@ -62,7 +62,7 @@ test_that("spark.fpGrowth", { expect_equivalent(expected_predictions, collect(predict(model, new_data))) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") write.ml(model, modelPath, overwrite = TRUE) loaded_model <- read.ml(modelPath) diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/tests/fulltests/test_mllib_recommendation.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_mllib_recommendation.R rename to R/pkg/tests/fulltests/test_mllib_recommendation.R index cc8064f88d27a..4d919c9d746b0 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R +++ b/R/pkg/tests/fulltests/test_mllib_recommendation.R @@ -37,7 +37,7 @@ test_that("spark.als", { tolerance = 1e-4) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R similarity index 99% rename from R/pkg/inst/tests/testthat/test_mllib_regression.R rename to R/pkg/tests/fulltests/test_mllib_regression.R index b05fdd350ca28..82472c92b9965 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/tests/fulltests/test_mllib_regression.R @@ -23,8 +23,6 @@ context("MLlib regression algorithms, except for tree-based algorithms") sparkSession <- 
sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("formula of spark.glm", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm @@ -197,8 +195,6 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -226,8 +222,6 @@ test_that("spark.glm save/load", { }) test_that("formula of glm", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . - Species + 0, data = training) @@ -254,8 +248,6 @@ test_that("formula of glm", { }) test_that("glm and predict", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) @@ -300,8 +292,6 @@ test_that("glm and predict", { }) test_that("glm summary", { - skip_on_cran() - # gaussian family training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -351,8 +341,6 @@ test_that("glm summary", { }) test_that("glm save/load", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) @@ -401,7 +389,7 @@ test_that("spark.isoreg", { expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -452,7 +440,7 @@ test_that("spark.survreg", { 2.390146, 2.891269, 2.891269), tolerance = 1e-4) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/tests/fulltests/test_mllib_stat.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_mllib_stat.R rename to R/pkg/tests/fulltests/test_mllib_stat.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_mllib_tree.R rename to R/pkg/tests/fulltests/test_mllib_tree.R index 31427ee52a5e9..9b3fc8d270b25 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/tests/fulltests/test_mllib_tree.R @@ -28,8 +28,6 @@ absoluteSparkPath <- function(x) { } test_that("spark.gbt", { - skip_on_cran() - # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) @@ -46,7 +44,7 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -80,7 +78,7 @@ test_that("spark.gbt", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - if (not_cran_or_windows_with_hadoop()) { + if 
(windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -105,7 +103,7 @@ test_that("spark.gbt", { expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), source = "libsvm") model <- spark.gbt(data, label ~ features, "classification") @@ -144,7 +142,7 @@ test_that("spark.randomForest", { expect_equal(stats$numTrees, 20) expect_equal(stats$maxDepth, 5) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -178,7 +176,7 @@ test_that("spark.randomForest", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -215,7 +213,7 @@ test_that("spark.randomForest", { expect_equal(length(grep("2.0", predictions)), 50) # spark.randomForest classification can work on libsvm data - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.randomForest(data, label ~ features, "classification") @@ -224,8 +222,6 @@ test_that("spark.randomForest", { }) test_that("spark.decisionTree", { - skip_on_cran() - # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16) @@ -242,7 +238,7 @@ test_that("spark.decisionTree", { expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -273,7 +269,7 @@ test_that("spark.decisionTree", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -309,7 +305,7 @@ test_that("spark.decisionTree", { expect_equal(length(grep("2.0", predictions)), 50) # spark.decisionTree classification can work on libsvm data - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.decisionTree(data, label ~ features, "classification") diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/tests/fulltests/test_parallelize_collect.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_parallelize_collect.R rename to R/pkg/tests/fulltests/test_parallelize_collect.R index 52d4c93ed9599..3d122ccaf448f 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ 
b/R/pkg/tests/fulltests/test_parallelize_collect.R @@ -39,8 +39,6 @@ jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", # Tests test_that("parallelize() on simple vectors and lists returns an RDD", { - skip_on_cran() - numVectorRDD <- parallelize(jsc, numVector, 1) numVectorRDD2 <- parallelize(jsc, numVector, 10) numListRDD <- parallelize(jsc, numList, 1) @@ -68,8 +66,6 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { }) test_that("collect(), following a parallelize(), gives back the original collections", { - skip_on_cran() - numVectorRDD <- parallelize(jsc, numVector, 10) expect_equal(collectRDD(numVectorRDD), as.list(numVector)) @@ -90,8 +86,6 @@ test_that("collect(), following a parallelize(), gives back the original collect }) test_that("regression: collect() following a parallelize() does not drop elements", { - skip_on_cran() - # 10 %/% 6 = 1, ceiling(10 / 6) = 2 collLen <- 10 numPart <- 6 @@ -101,8 +95,6 @@ test_that("regression: collect() following a parallelize() does not drop element }) test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { - skip_on_cran() - # use the pairwise logical to indicate pairwise data numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_rdd.R rename to R/pkg/tests/fulltests/test_rdd.R index fb244e1d49e20..6ee1fceffd822 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/tests/fulltests/test_rdd.R @@ -29,30 +29,22 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - skip_on_cran() - expect_equal(getNumPartitionsRDD(rdd), 2) expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { - skip_on_cran() - expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - skip_on_cran() - expect_equal(countRDD(rdd), 10) expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { - skip_on_cran() - mods <- lapply(rdd, function(x) { x %% 3 }) actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) @@ -64,40 +56,30 @@ test_that("count by values and keys", { }) test_that("lapply on RDD", { - skip_on_cran() - multiples <- lapply(rdd, function(x) { 2 * x }) actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { - skip_on_cran() - sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { - skip_on_cran() - sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { - skip_on_cran() - flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { - skip_on_cran() - filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) @@ -113,8 +95,6 @@ test_that("filterRDD on RDD", { }) test_that("lookup on RDD", { - skip_on_cran() - vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) @@ -123,8 
+103,6 @@ test_that("lookup on RDD", { }) test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { - skip_on_cran() - rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( @@ -139,8 +117,6 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { - skip_on_cran() - # RDD rdd2 <- rdd # PipelinedRDD @@ -182,8 +158,6 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp }) test_that("reduce on RDD", { - skip_on_cran() - sum <- reduce(rdd, "+") expect_equal(sum, 55) @@ -193,8 +167,6 @@ test_that("reduce on RDD", { }) test_that("lapply with dependency", { - skip_on_cran() - fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) actual <- collectRDD(multiples) @@ -203,8 +175,6 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { - skip_on_cran() - func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) @@ -221,14 +191,10 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { - skip_on_cran() - expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { - skip_on_cran() - # ported from RDDSuite.scala, modified seeds data <- parallelize(sc, 1:100, 2L) for (seed in 4:5) { @@ -271,8 +237,6 @@ test_that("takeSample() on RDDs", { }) test_that("mapValues() on pairwise RDDs", { - skip_on_cran() - multiples <- mapValues(intRdd, function(x) { x * 2 }) actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { @@ -282,8 +246,6 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { - skip_on_cran() - l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -296,8 +258,6 @@ test_that("flatMapValues() on pairwise RDDs", { }) test_that("reduceByKeyLocally() on PairwiseRDDs", { - skip_on_cran() - pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) actual <- reduceByKeyLocally(pairs, "+") expect_equal(sortKeyValueList(actual), @@ -311,8 +271,6 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { }) test_that("distinct() on RDDs", { - skip_on_cran() - nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) uniques <- distinctRDD(rdd.rep2) @@ -321,29 +279,21 @@ test_that("distinct() on RDDs", { }) test_that("maximum() on RDDs", { - skip_on_cran() - max <- maximum(rdd) expect_equal(max, 10) }) test_that("minimum() on RDDs", { - skip_on_cran() - min <- minimum(rdd) expect_equal(min, 1) }) test_that("sumRDD() on RDDs", { - skip_on_cran() - sum <- sumRDD(rdd) expect_equal(sum, 55) }) test_that("keyBy on RDDs", { - skip_on_cran() - func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collectRDD(keys) @@ -351,8 +301,6 @@ test_that("keyBy on RDDs", { }) test_that("repartition/coalesce on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition @@ -374,8 +322,6 @@ test_that("repartition/coalesce on RDDs", { }) test_that("sortBy() on RDDs", { - skip_on_cran() - sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) actual <- collectRDD(sortedRdd) expect_equal(actual, 
as.list(sort(nums, decreasing = TRUE))) @@ -387,8 +333,6 @@ test_that("sortBy() on RDDs", { }) test_that("takeOrdered() on RDDs", { - skip_on_cran() - l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- takeOrdered(rdd, 6L) @@ -401,8 +345,6 @@ test_that("takeOrdered() on RDDs", { }) test_that("top() on RDDs", { - skip_on_cran() - l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- top(rdd, 6L) @@ -415,8 +357,6 @@ test_that("top() on RDDs", { }) test_that("fold() on RDDs", { - skip_on_cran() - actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) @@ -426,8 +366,6 @@ test_that("fold() on RDDs", { }) test_that("aggregateRDD() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, list(1, 2, 3, 4)) zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } @@ -441,8 +379,6 @@ test_that("aggregateRDD() on RDDs", { }) test_that("zipWithUniqueId() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 4), @@ -457,8 +393,6 @@ test_that("zipWithUniqueId() on RDDs", { }) test_that("zipWithIndex() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), @@ -473,32 +407,24 @@ test_that("zipWithIndex() on RDDs", { }) test_that("glom() on RDD", { - skip_on_cran() - rdd <- parallelize(sc, as.list(1:4), 2L) actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { - skip_on_cran() - keys <- keys(intRdd) actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { - skip_on_cran() - values <- values(intRdd) actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { - skip_on_cran() - actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) @@ -516,8 +442,6 @@ test_that("pipeRDD() on RDDs", { }) test_that("zipRDD() on RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) actual <- collectRDD(zipRDD(rdd1, rdd2)) @@ -547,8 +471,6 @@ test_that("zipRDD() on RDDs", { }) test_that("cartesian() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, 1:3) actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), @@ -592,8 +514,6 @@ test_that("cartesian() on RDDs", { }) test_that("subtract() on RDDs", { - skip_on_cran() - l <- list(1, 1, 2, 2, 3, 4) rdd1 <- parallelize(sc, l) @@ -621,8 +541,6 @@ test_that("subtract() on RDDs", { }) test_that("subtractByKey() on pairwise RDDs", { - skip_on_cran() - l <- list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)) rdd1 <- parallelize(sc, l) @@ -652,8 +570,6 @@ test_that("subtractByKey() on pairwise RDDs", { }) test_that("intersection() on RDDs", { - skip_on_cran() - # intersection with self actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) @@ -670,8 +586,6 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) @@ -696,8 +610,6 @@ test_that("join() on pairwise RDDs", { }) 
test_that("leftOuterJoin() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) @@ -728,8 +640,6 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) @@ -757,8 +667,6 @@ test_that("rightOuterJoin() on pairwise RDDs", { }) test_that("fullOuterJoin() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) @@ -790,8 +698,6 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { - skip_on_cran() - numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) @@ -841,8 +747,6 @@ test_that("sortByKey() on pairwise RDDs", { }) test_that("collectAsMap() on a pairwise RDD", { - skip_on_cran() - rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = 2, `3` = 4)) @@ -861,15 +765,11 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { - skip_on_cran() - rdd <- parallelize(sc, list(1:10)) expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { - skip_on_cran() - rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) fractions <- list(a = 0.2, b = 0.1) @@ -894,8 +794,6 @@ test_that("sampleByKey() on pairwise RDDs", { }) test_that("Test correct concurrency of RRDD.compute()", { - skip_on_cran() - rdd <- parallelize(sc, 1:1000, 100) jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") zrdd <- callJMethod(jrdd, "zip", jrdd) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/tests/fulltests/test_shuffle.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_shuffle.R rename to R/pkg/tests/fulltests/test_shuffle.R index 18320ea44b389..98300c67c415f 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/tests/fulltests/test_shuffle.R @@ -37,8 +37,6 @@ strList <- list("Dexter Morgan: Blood. 
Sometimes it sets my teeth on edge and ", strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { - skip_on_cran() - grouped <- groupByKey(intRdd, 2L) actual <- collectRDD(grouped) @@ -48,8 +46,6 @@ test_that("groupByKey for integers", { }) test_that("groupByKey for doubles", { - skip_on_cran() - grouped <- groupByKey(doubleRdd, 2L) actual <- collectRDD(grouped) @@ -59,8 +55,6 @@ test_that("groupByKey for doubles", { }) test_that("reduceByKey for ints", { - skip_on_cran() - reduced <- reduceByKey(intRdd, "+", 2L) actual <- collectRDD(reduced) @@ -70,8 +64,6 @@ test_that("reduceByKey for ints", { }) test_that("reduceByKey for doubles", { - skip_on_cran() - reduced <- reduceByKey(doubleRdd, "+", 2L) actual <- collectRDD(reduced) @@ -80,8 +72,6 @@ test_that("reduceByKey for doubles", { }) test_that("combineByKey for ints", { - skip_on_cran() - reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -91,8 +81,6 @@ test_that("combineByKey for ints", { }) test_that("combineByKey for doubles", { - skip_on_cran() - reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -101,8 +89,6 @@ test_that("combineByKey for doubles", { }) test_that("combineByKey for characters", { - skip_on_cran() - stringKeyRDD <- parallelize(sc, list(list("max", 1L), list("min", 2L), list("other", 3L), list("max", 4L)), 2L) @@ -115,8 +101,6 @@ test_that("combineByKey for characters", { }) test_that("aggregateByKey", { - skip_on_cran() - # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -145,8 +129,6 @@ test_that("aggregateByKey", { }) test_that("foldByKey", { - skip_on_cran() - # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) @@ -190,8 +172,6 @@ test_that("foldByKey", { }) test_that("partitionBy() partitions data correctly", { - skip_on_cran() - # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } @@ -207,8 +187,6 @@ test_that("partitionBy() partitions data correctly", { }) test_that("partitionBy works with dependencies", { - skip_on_cran() - kOne <- 1 partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } @@ -227,8 +205,6 @@ test_that("partitionBy works with dependencies", { }) test_that("test partitionBy with string keys", { - skip_on_cran() - words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/tests/fulltests/test_sparkR.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_sparkR.R rename to R/pkg/tests/fulltests/test_sparkR.R index a40981c188f7a..f73fc6baeccef 100644 --- a/R/pkg/inst/tests/testthat/test_sparkR.R +++ b/R/pkg/tests/fulltests/test_sparkR.R @@ -18,8 +18,6 @@ context("functions in sparkR.R") test_that("sparkCheckInstall", { - skip_on_cran() - # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, # and the SparkR job was submitted by "spark-submit" sparkHome <- paste0(tempdir(), "/", "sparkHome") diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R similarity index 99% rename from R/pkg/inst/tests/testthat/test_sparkSQL.R rename to R/pkg/tests/fulltests/test_sparkSQL.R index c790d02b107be..af529067f43e0 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -61,7 +61,7 @@ unsetHiveContext <- function() 
{ # Tests for SparkSQL functions in SparkR filesBefore <- list.files(path = sparkRDir, all.files = TRUE) -sparkSession <- if (not_cran_or_windows_with_hadoop()) { +sparkSession <- if (windows_with_hadoop()) { sparkR.session(master = sparkRTestMaster) } else { sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) @@ -100,26 +100,20 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) -if (.Platform$OS.type == "windows") { +if (is_windows()) { Sys.setenv(TZ = "GMT") } test_that("calling sparkRSQL.init returns existing SQL context", { - skip_on_cran() - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) }) test_that("calling sparkRSQL.init returns existing SparkSession", { - skip_on_cran() - expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) test_that("calling sparkR.session returns existing SparkSession", { - skip_on_cran() - expect_equal(sparkR.session(), sparkSession) }) @@ -217,8 +211,6 @@ test_that("structField type strings", { }) test_that("create DataFrame from RDD", { - skip_on_cran() - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) @@ -316,8 +308,6 @@ test_that("create DataFrame from RDD", { }) test_that("createDataFrame uses files for large objects", { - skip_on_cran() - # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") @@ -330,7 +320,7 @@ test_that("createDataFrame uses files for large objects", { }) test_that("read/write csv as DataFrame", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -380,8 +370,6 @@ test_that("read/write csv as DataFrame", { }) test_that("Support other types for options", { - skip_on_cran() - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -436,8 +424,6 @@ test_that("convert NAs to null type in DataFrames", { }) test_that("toDF", { - skip_on_cran() - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") @@ -549,8 +535,6 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { - skip_on_cran() - ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) @@ -563,8 +547,6 @@ test_that("create DataFrame from a data.frame with complex types", { }) test_that("Collect DataFrame with complex types", { - skip_on_cran() - # ArrayType df <- read.json(complexTypeJsonPath) ldf <- collect(df) @@ -607,7 +589,7 @@ test_that("Collect DataFrame with complex types", { }) test_that("read/write json files", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { # Test read.df df <- read.df(jsonPath, "json") expect_is(df, "SparkDataFrame") @@ -654,8 +636,6 @@ test_that("read/write json files", { }) 
test_that("read/write json files - compression option", { - skip_on_cran() - df <- read.df(jsonPath, "json") jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") @@ -669,8 +649,6 @@ test_that("read/write json files - compression option", { }) test_that("jsonRDD() on a RDD with json string", { - skip_on_cran() - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(countRDD(rdd), 3) @@ -730,8 +708,6 @@ test_that( }) test_that("test cache, uncache and clearCache", { - skip_on_cran() - df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") cacheTable("table1") @@ -744,7 +720,7 @@ test_that("test cache, uncache and clearCache", { }) test_that("insertInto() on a registered table", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { df <- read.df(jsonPath, "json") write.df(df, parquetPath, "parquet", "overwrite") dfParquet <- read.df(parquetPath, "parquet") @@ -787,8 +763,6 @@ test_that("tableToDF() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { - skip_on_cran() - df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") @@ -796,8 +770,6 @@ test_that("toRDD() returns an RRDD", { }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - skip_on_cran() - df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) @@ -808,8 +780,6 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { }) test_that("union on mixed serialization types correctly returns a byte RRDD", { - skip_on_cran() - # Byte RDD nums <- 1:10 rdd <- parallelize(sc, nums, 2L) @@ -839,8 +809,6 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) test_that("objectFile() works with row serialization", { - skip_on_cran() - objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) @@ -853,8 +821,6 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - skip_on_cran() - df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 @@ -923,8 +889,6 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { - skip_on_cran() - df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -964,7 +928,7 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", }) test_that("setCheckpointDir(), checkpoint() on a DataFrame", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { checkpointDir <- file.path(tempdir(), "cproot") expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) @@ -1341,7 +1305,7 @@ test_that("column calculation", { }) test_that("test HiveContext", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { setHiveContext(sc) schema <- structType(structField("name", "string"), structField("age", "integer"), @@ -1395,8 +1359,6 @@ test_that("column operators", { }) test_that("column functions", { - skip_on_cran() - c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) @@ -1782,8 +1744,6 @@ test_that("when(), otherwise() and ifelse() with column on a DataFrame", { }) test_that("group by, agg functions", { - skip_on_cran() - df <- read.json(jsonPath) df1 <- 
agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) @@ -2125,8 +2085,6 @@ test_that("filter() on a DataFrame", { }) test_that("join(), crossJoin() and merge() on a DataFrame", { - skip_on_cran() - df <- read.json(jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", @@ -2400,8 +2358,6 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write ORC files", { - skip_on_cran() - setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2423,8 +2379,6 @@ test_that("read/write ORC files", { }) test_that("read/write ORC files - compression option", { - skip_on_cran() - setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2440,7 +2394,7 @@ test_that("read/write ORC files - compression option", { }) test_that("read/write Parquet files", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { df <- read.df(jsonPath, "json") # Test write.df and read.df write.df(df, parquetPath, "parquet", mode = "overwrite") @@ -2473,8 +2427,6 @@ test_that("read/write Parquet files", { }) test_that("read/write Parquet files - compression option/mode", { - skip_on_cran() - df <- read.df(jsonPath, "json") tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") @@ -2492,8 +2444,6 @@ test_that("read/write Parquet files - compression option/mode", { }) test_that("read/write text files", { - skip_on_cran() - # Test write.df and read.df df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") @@ -2515,8 +2465,6 @@ test_that("read/write text files", { }) test_that("read/write text files - compression option", { - skip_on_cran() - df <- read.df(jsonPath, "text") textPath <- tempfile(pattern = "textPath", fileext = ".txt") @@ -2750,8 +2698,6 @@ test_that("approxQuantile() on a DataFrame", { }) test_that("SQL error message is returned from JVM", { - skip_on_cran() - retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) @@ -2760,8 +2706,6 @@ test_that("SQL error message is returned from JVM", { irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { - skip_on_cran() - expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -2984,8 +2928,6 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { }) test_that("dapplyCollect() on DataFrame with a binary column", { - skip_on_cran() - df <- data.frame(key = 1:3) df$bytes <- lapply(df$key, serialize, connection = NULL) @@ -3006,8 +2948,6 @@ test_that("dapplyCollect() on DataFrame with a binary column", { }) test_that("repartition by columns on DataFrame", { - skip_on_cran() - df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -3046,8 +2986,6 @@ test_that("repartition by columns on DataFrame", { }) test_that("coalesce, repartition, numPartitions", { - skip_on_cran() - df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) expect_equal(getNumPartitions(coalesce(df, 3)), 3) @@ -3067,8 +3005,6 @@ test_that("coalesce, repartition, numPartitions", { }) test_that("gapply() and gapplyCollect() on a DataFrame", { - skip_on_cran() - df <- createDataFrame ( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -3186,8 +3122,6 @@ test_that("Window functions on a DataFrame", { }) 
test_that("createDataFrame sqlContext parameter backward compatibility", { - skip_on_cran() - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") @@ -3221,8 +3155,6 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { }) test_that("randomSplit", { - skip_on_cran() - num <- 4000 df <- createDataFrame(data.frame(id = 1:num)) weights <- c(2, 3, 5) @@ -3269,8 +3201,6 @@ test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiW }) test_that("enableHiveSupport on SparkSession", { - skip_on_cran() - setHiveContext(sc) unsetHiveContext() # if we are still here, it must be built with hive @@ -3286,8 +3216,6 @@ test_that("Spark version from SparkSession", { }) test_that("Call DataFrameWriter.save() API in Java without path and check argument types", { - skip_on_cran() - df <- read.df(jsonPath, "json") # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in write.df API and then it calls @@ -3314,8 +3242,6 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { - skip_on_cran() - # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. @@ -3440,8 +3366,6 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { - skip_on_cran() # skip because when run from R CMD check SPARK_HOME is not the current directory - # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. 
filesAfter <- list.files(path = sparkRDir, all.files = TRUE) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_streaming.R rename to R/pkg/tests/fulltests/test_streaming.R index b20b4312fbaae..d691de7cd725d 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -24,7 +24,7 @@ context("Structured Streaming") sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsonSubDir <- file.path("sparkr-test", "json", "") -if (.Platform$OS.type == "windows") { +if (is_windows()) { # file.path removes the empty separator on Windows, adds it back jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep) } @@ -47,8 +47,6 @@ schema <- structType(structField("name", "string"), structField("count", "double")) test_that("read.stream, write.stream, awaitTermination, stopQuery", { - skip_on_cran() - df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -69,8 +67,6 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { }) test_that("print from explain, lastProgress, status, isActive", { - skip_on_cran() - df <- read.stream("json", path = jsonDir, schema = schema) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -90,8 +86,6 @@ test_that("print from explain, lastProgress, status, isActive", { }) test_that("Stream other format", { - skip_on_cran() - parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") df <- read.df(jsonPath, "json", schema) write.df(df, parquetPath, "parquet", "overwrite") @@ -118,8 +112,6 @@ test_that("Stream other format", { }) test_that("Non-streaming DataFrame", { - skip_on_cran() - c <- as.DataFrame(cars) expect_false(isStreaming(c)) @@ -129,8 +121,6 @@ test_that("Non-streaming DataFrame", { }) test_that("Unsupported operation", { - skip_on_cran() - # memory sink without aggregation df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), @@ -139,8 +129,6 @@ test_that("Unsupported operation", { }) test_that("Terminated by error", { - skip_on_cran() - df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) counts <- count(group_by(df, "name")) # This would not fail before returning with a StreamingQuery, diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/tests/fulltests/test_take.R similarity index 99% rename from R/pkg/inst/tests/testthat/test_take.R rename to R/pkg/tests/fulltests/test_take.R index c00723ba31f4c..8936cc57da227 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/tests/fulltests/test_take.R @@ -34,8 +34,6 @@ sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FA sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { - skip_on_cran() - numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/tests/fulltests/test_textFile.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_textFile.R rename to 
R/pkg/tests/fulltests/test_textFile.R index e8a961cb3e870..be2d2711ff88e 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/tests/fulltests/test_textFile.R @@ -24,8 +24,6 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -38,8 +36,6 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -50,8 +46,6 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -70,8 +64,6 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -86,8 +78,6 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -102,8 +92,6 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) @@ -115,8 +103,6 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -142,8 +128,6 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) @@ -157,8 +141,6 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/tests/fulltests/test_utils.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_utils.R rename to R/pkg/tests/fulltests/test_utils.R index 6197ae7569879..af81423aa8dd0 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -23,7 +23,6 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { - skip_on_cran() # It's hard to manually create a Java List using rJava, since it does not # support generics well. 
Instead, we rely on collectRDD() returning a # JList. @@ -41,7 +40,6 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists }) test_that("serializeToBytes on RDD", { - skip_on_cran() # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -169,8 +167,6 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { - skip_on_cran() - method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, "col", "unknown", TRUE), @@ -181,8 +177,6 @@ test_that("captureJVMException", { }) test_that("hashCode", { - skip_on_cran() - expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) @@ -243,6 +237,3 @@ test_that("basenameSansExtFromUrl", { }) sparkR.session.stop() - -message("--- End test (utils) ", as.POSIXct(Sys.time(), tz = "GMT")) -message("elapsed ", (proc.time() - timer_ptm)[3]) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index f0bef4f6d2662..f00a610679752 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -24,8 +24,6 @@ options("warn" = 2) if (.Platform$OS.type == "windows") { Sys.setenv(TZ = "GMT") } -message("--- Start test ", as.POSIXct(Sys.time(), tz = "GMT")) -timer_ptm <- proc.time() # Setup global test environment # Install Spark first to set SPARK_HOME @@ -43,3 +41,11 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { } test_package("SparkR") + +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + # for testthat 1.0.2 later, change reporter from "summary" to default_reporter() + testthat:::run_tests("SparkR", + file.path(sparkRDir, "pkg", "tests", "fulltests"), + NULL, + "summary") +} diff --git a/core/pom.xml b/core/pom.xml index 7f245b5b6384a..326dde4f274bb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -357,6 +357,34 @@ org.apache.commons commons-crypto + + + + ${hive.group} + hive-exec + provided + + + ${hive.group} + hive-metastore + provided + + + org.apache.thrift + libthrift + provided + + + org.apache.thrift + libfb303 + provided + + target/scala-${scala.binary.version}/classes diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4ef6656222455..5d48bc7c96555 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -34,6 +34,173 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ +/** + * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single + * ShuffleMapStage. + * + * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of + * serialized map statuses in order to speed up tasks' requests for map output statuses. + * + * All public methods of this class are thread-safe. + */ +private class ShuffleStatus(numPartitions: Int) { + + // All accesses to the following state must be guarded with `this.synchronized`. + + /** + * MapStatus for each partition. The index of the array is the map partition id. + * Each value in the array is the MapStatus for a partition, or null if the partition + * is not available. 
Even though in theory a task may run multiple times (due to speculation, + * stage retries, etc.), in practice the likelihood of a map output being available at multiple + * locations is so small that we choose to ignore that case and store only a single location + * for each output. + */ + // Exposed for testing + val mapStatuses = new Array[MapStatus](numPartitions) + + /** + * The cached result of serializing the map statuses array. This cache is lazily populated when + * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. + */ + private[this] var cachedSerializedMapStatus: Array[Byte] = _ + + /** + * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]] + * serializes the map statuses array it may detect that the result is too large to send in a + * single RPC, in which case it places the serialized array into a broadcast variable and then + * sends a serialized broadcast variable instead. This variable holds a reference to that + * broadcast variable in order to keep it from being garbage collected and to allow for it to be + * explicitly destroyed later on when the ShuffleMapStage is garbage-collected. + */ + private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + + /** + * Counter tracking the number of partitions that have output. This is a performance optimization + * to avoid having to count the number of non-null entries in the `mapStatuses` array and should + * be equivalent to`mapStatuses.count(_ ne null)`. + */ + private[this] var _numAvailableOutputs: Int = 0 + + /** + * Register a map output. If there is already a registered location for the map output then it + * will be replaced by the new location. + */ + def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { + if (mapStatuses(mapId) == null) { + _numAvailableOutputs += 1 + invalidateSerializedMapOutputStatusCache() + } + mapStatuses(mapId) = status + } + + /** + * Remove the map output which was served by the specified block manager. + * This is a no-op if there is no registered map output or if the registered output is from a + * different block manager. + */ + def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all shuffle outputs associated with this host. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists). + */ + def removeOutputsOnHost(host: String): Unit = { + removeOutputsByFilter(x => x.host == host) + } + + /** + * Removes all map outputs associated with the specified executor. Note that this will also + * remove outputs which are served by an external shuffle server (if one exists), as they are + * still registered with that execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = synchronized { + removeOutputsByFilter(x => x.executorId == execId) + } + + /** + * Removes all shuffle outputs which satisfies the filter. Note that this will also + * remove outputs which are served by an external shuffle server (if one exists). 
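The bookkeeping pattern introduced above (a fixed-size array of per-map statuses plus a running counter of available outputs, with every access guarded by this.synchronized) can be illustrated outside Spark with a short standalone sketch. The names below (SimpleMapStatus, SimpleShuffleStatus) are stand-ins invented for the example, not Spark classes.

// Minimal sketch of per-shuffle map-output bookkeeping; a MapStatus is reduced
// to just the id of the executor serving the output.
final case class SimpleMapStatus(execId: String)

final class SimpleShuffleStatus(numPartitions: Int) {
  // Index = map partition id; null means the output is not currently available.
  private val statuses = new Array[SimpleMapStatus](numPartitions)
  // Kept equal to statuses.count(_ ne null) so callers never have to rescan.
  private var available = 0

  def addMapOutput(mapId: Int, status: SimpleMapStatus): Unit = synchronized {
    if (statuses(mapId) == null) available += 1
    statuses(mapId) = status
  }

  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
    for (mapId <- statuses.indices) {
      if (statuses(mapId) != null && statuses(mapId).execId == execId) {
        statuses(mapId) = null
        available -= 1
      }
    }
  }

  def numAvailableOutputs: Int = synchronized { available }
}

object SimpleShuffleStatusDemo extends App {
  val s = new SimpleShuffleStatus(3)
  s.addMapOutput(0, SimpleMapStatus("exec-1"))
  s.addMapOutput(1, SimpleMapStatus("exec-2"))
  s.removeOutputsOnExecutor("exec-1")   // simulates losing an executor
  println(s.numAvailableOutputs)        // 1
}

Keeping the counter in step with every mutation is the design choice that makes numAvailableOutputs O(1) instead of a rescan of the array.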
+ */ + def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized { + for (mapId <- 0 until mapStatuses.length) { + if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + } + + /** + * Number of partitions that have shuffle outputs. + */ + def numAvailableOutputs: Int = synchronized { + _numAvailableOutputs + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + */ + def findMissingPartitions(): Seq[Int] = synchronized { + val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) + assert(missing.size == numPartitions - _numAvailableOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") + missing + } + + /** + * Serializes the mapStatuses array into an efficient compressed format. See the comments on + * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the map statuses then serialization will only be performed in a single thread and all + * other threads will block until the cache is populated. + */ + def serializedMapStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int): Array[Byte] = synchronized { + if (cachedSerializedMapStatus eq null) { + val serResult = MapOutputTracker.serializeMapStatuses( + mapStatuses, broadcastManager, isLocal, minBroadcastSize) + cachedSerializedMapStatus = serResult._1 + cachedSerializedBroadcast = serResult._2 + } + cachedSerializedMapStatus + } + + // Used in testing. + def hasCachedSerializedBroadcast: Boolean = synchronized { + cachedSerializedBroadcast != null + } + + /** + * Helper function which provides thread-safe access to the mapStatuses array. + * The function should NOT mutate the array. + */ + def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized { + f(mapStatuses) + } + + /** + * Clears the cached serialized map output statuses. + */ + def invalidateSerializedMapOutputStatusCache(): Unit = synchronized { + if (cachedSerializedBroadcast != null) { + cachedSerializedBroadcast.destroy() + cachedSerializedBroadcast = null + } + cachedSerializedMapStatus = null + } +} + private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage @@ -62,37 +229,26 @@ private[spark] class MapOutputTrackerMasterEndpoint( } /** - * Class that keeps track of the location of the map output of - * a stage. This is abstract because different versions of MapOutputTracker - * (driver and executor) use different HashMap to store its metadata. - */ + * Class that keeps track of the location of the map output of a stage. This is abstract because the + * driver and executor have different versions of the MapOutputTracker. In principle the driver- + * and executor-side classes don't need to share a common base class; the current shared base class + * is maintained primarily for backwards-compatibility in order to avoid having to update existing + * test code. +*/ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - /** Set to the MapOutputTrackerMasterEndpoint living on the driver. 
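serializedMapStatus() and invalidateSerializedMapOutputStatusCache() above implement a serialize-once cache that is dropped whenever the underlying state changes. Below is a minimal sketch of that pattern, assuming plain Java serialization in place of MapOutputTracker.serializeMapStatuses() and ignoring the broadcast-variable path; all names are invented for the illustration.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Sketch of a lazily populated, explicitly invalidated cache of a serialized snapshot.
final class CachedSnapshot[T <: java.io.Serializable](initial: T) {
  private var value: T = initial
  private var cachedBytes: Array[Byte] = _   // null means "not serialized yet"

  def update(newValue: T): Unit = synchronized {
    value = newValue
    cachedBytes = null                       // any change invalidates the cache
  }

  // Serializes at most once per generation of the value; concurrent callers
  // block on the lock until the first one has populated the cache.
  def serialized: Array[Byte] = synchronized {
    if (cachedBytes == null) {
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(value)
      oos.close()
      cachedBytes = bos.toByteArray
    }
    cachedBytes
  }
}

object CachedSnapshotDemo extends App {
  val snap = new CachedSnapshot[Vector[String]](Vector("exec-1", "exec-2"))
  val a = snap.serialized
  println(a eq snap.serialized)    // true: second call reuses the cached bytes
  snap.update(Vector("exec-1"))
  println(snap.serialized eq a)    // false: invalidated and re-serialized
}

The real class additionally destroys the cached broadcast variable on invalidation; the sketch only drops the byte array.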
*/ var trackerEndpoint: RpcEndpointRef = _ /** - * This HashMap has different behavior for the driver and the executors. - * - * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks. - * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the - * driver's corresponding HashMap. - * - * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a - * thread-safe map. - */ - protected val mapStatuses: Map[Int, Array[MapStatus]] - - /** - * Incremented every time a fetch fails so that client nodes know to clear - * their cache of map output locations if this happens. + * The driver-side counter is incremented every time that a map output is lost. This value is sent + * to executors as part of tasks, where executors compare the new epoch number to the highest + * epoch number that they received in the past. If the new epoch number is higher then executors + * will clear their local caches of map output statuses and will re-fetch (possibly updated) + * statuses from the driver. */ protected var epoch: Long = 0 protected val epochLock = new AnyRef - /** Remembers which map output locations are currently being fetched on an executor. */ - private val fetching = new HashSet[Int] - /** * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. @@ -116,14 +272,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** - * Called from executors to get the server URIs and output sizes for each shuffle block that - * needs to be read from a given reduce task. - * - * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block id, shuffle block size) tuples - * describing the shuffle blocks that are stored at that block manager. - */ + // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) @@ -139,135 +288,31 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { - logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) - } - } - - /** - * Return statistics about all of the outputs for a given shuffle. - */ - def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { - val statuses = getStatuses(dep.shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { - for (i <- 0 until totalSizes.length) { - totalSizes(i) += s.getSizeForBlock(i) - } - } - new MapOutputStatistics(dep.shuffleId, totalSizes) - } - } - - /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize - * on this array when reading it, because on the driver, we may be changing it in place. 
- * - * (It would be nice to remove this restriction in the future.) - */ - private def getStatuses(shuffleId: Int): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - val startTime = System.currentTimeMillis - var fetchedStatuses: Array[MapStatus] = null - fetching.synchronized { - // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { - try { - fetching.wait() - } catch { - case e: InterruptedException => - } - } - - // Either while we waited the fetch happened successfully, or - // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - // We have to do the fetch, get others to wait for us. - fetching += shuffleId - } - } - - if (fetchedStatuses == null) { - // We won the race to fetch the statuses; do so - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - // This try-finally prevents hangs due to timeouts: - try { - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() - } - } - } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + - s"${System.currentTimeMillis - startTime} ms") - - if (fetchedStatuses != null) { - return fetchedStatuses - } else { - logError("Missing all output locations for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) - } - } else { - return statuses - } - } - - /** Called to get current epoch number. */ - def getEpoch: Long = { - epochLock.synchronized { - return epoch - } - } + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] /** - * Called from executors to update the epoch number, potentially clearing old outputs - * because of a fetch failure. Each executor task calls this with the latest epoch - * number on the driver at the time it was created. + * Deletes map output status information for the specified shuffle stage. */ - def updateEpoch(newEpoch: Long) { - epochLock.synchronized { - if (newEpoch > epoch) { - logInfo("Updating epoch to " + newEpoch + " and clearing cache") - epoch = newEpoch - mapStatuses.clear() - } - } - } + def unregisterShuffle(shuffleId: Int): Unit - /** Unregister shuffle data. */ - def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - } - - /** Stop the tracker. */ - def stop() { } + def stop() {} } /** - * MapOutputTracker for the driver. + * Driver-side class that keeps track of the location of the map output of a stage. + * + * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics + * for performing locality-aware reduce task scheduling. + * + * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine + * which tasks need to be run. 
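A toy version of the epoch protocol described in the MapOutputTracker comment above: the driver bumps a counter each time a map output is lost, the counter travels to executors with tasks, and an executor clears its local cache of output locations whenever it sees an epoch newer than any it has seen before. All names here are invented for the illustration.

import scala.collection.mutable

// Driver side: the epoch is bumped whenever a map output is lost.
final class EpochSource {
  private var epoch = 0L
  def increment(): Unit = synchronized { epoch += 1 }
  def current: Long = synchronized { epoch }
}

// Executor side: a cache keyed by shuffle id that is cleared when a newer
// epoch arrives with a task.
final class EpochAwareCache[V] {
  private var lastSeenEpoch = 0L
  private val cache = mutable.Map.empty[Int, V]

  def updateEpoch(newEpoch: Long): Unit = synchronized {
    if (newEpoch > lastSeenEpoch) {
      lastSeenEpoch = newEpoch
      cache.clear()              // stale locations must be re-fetched
    }
  }

  def getOrFetch(shuffleId: Int)(fetch: => V): V = synchronized {
    cache.getOrElseUpdate(shuffleId, fetch)
  }
}

object EpochDemo extends App {
  val driverEpoch = new EpochSource
  val executorCache = new EpochAwareCache[String]

  executorCache.getOrFetch(1)("locations-v1")            // cached
  driverEpoch.increment()                                // a map output was lost
  executorCache.updateEpoch(driverEpoch.current)         // arrives with the next task
  println(executorCache.getOrFetch(1)("locations-v2"))   // re-fetched: locations-v2
}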
*/ -private[spark] class MapOutputTrackerMaster(conf: SparkConf, - broadcastManager: BroadcastManager, isLocal: Boolean) +private[spark] class MapOutputTrackerMaster( + conf: SparkConf, + broadcastManager: BroadcastManager, + isLocal: Boolean) extends MapOutputTracker(conf) { - /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ - private var cacheEpoch = epoch - // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt @@ -287,22 +332,13 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 - // HashMaps for storing mapStatuses and cached serialized statuses in the driver. + // HashMap for storing shuffleStatuses in the driver. // Statuses are dropped only by explicit de-registering. - protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala - private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala + // Exposed for testing + val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // Kept in sync with cachedSerializedStatuses explicitly - // This is required so that the Broadcast variable remains in scope until we remove - // the shuffleId explicitly or implicitly. - private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]() - - // This is to prevent multiple serializations of the same shuffle - which happens when - // there is a request storm when shuffle start. - private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]() - // requests for map output statuses private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] @@ -348,8 +384,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, val hostPort = context.senderAddress.hostPort logDebug("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) - context.reply(mapOutputStatuses) + val shuffleStatus = shuffleStatuses.get(shuffleId).head + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -363,59 +400,86 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, /** A poison endpoint that indicates MessageLoop should exit its message loop. */ private val PoisonPill = new GetMapOutputMessage(-99, null) - // Exposed for testing - private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size + // Used only in unit tests. 
+ private[spark] def getNumCachedSerializedBroadcast: Int = { + shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) + } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - // add in advance - shuffleIdLocks.putIfAbsent(shuffleId, new Object()) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - val array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } - } - - /** Register multiple map output information for the given shuffle */ - def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, statuses.clone()) - if (changeEpoch) { - incrementEpoch() - } + shuffleStatuses(shuffleId).addMapOutput(mapId, status) } /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - val arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - val array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMapOutput(mapId, bmAddress) + incrementEpoch() + case None => + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } } /** Unregister shuffle data */ - override def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - cachedSerializedStatuses.remove(shuffleId) - cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v)) - shuffleIdLocks.remove(shuffleId) + def unregisterShuffle(shuffleId: Int) { + shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => + shuffleStatus.invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all shuffle outputs associated with this host. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists). + */ + def removeOutputsOnHost(host: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) } + incrementEpoch() + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) } + incrementEpoch() } /** Check if the given shuffle is being tracked */ - def containsShuffle(shuffleId: Int): Boolean = { - cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) + def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) + + def getNumAvailableOutputs(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0) + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None + * if the MapOutputTrackerMaster doesn't know about this shuffle. 
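The findMissingPartitions lookup documented above (its implementation follows) boils down to scanning a status array for null entries and returning None when the shuffle is not tracked at all. A self-contained restatement, with invented names and a String standing in for MapStatus:

// Sketch: missing-partition lookup over a map of per-shuffle status arrays.
object MissingPartitionsDemo extends App {
  type MapStatus = String   // stand-in: the executor currently serving the output

  val shuffleStatuses: Map[Int, Array[MapStatus]] = Map(
    7 -> Array("exec-1", null, "exec-2", null)
  )

  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] =
    shuffleStatuses.get(shuffleId).map { statuses =>
      statuses.indices.filter(statuses(_) == null)
    }

  println(findMissingPartitions(7))    // Some(Vector(1, 3))
  println(findMissingPartitions(8))    // None: shuffle not tracked
}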
+ */ + def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = { + shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) + } + + /** + * Return statistics about all of the outputs for a given shuffle. + */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } } /** @@ -459,9 +523,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, fractionThreshold: Double) : Option[Array[BlockManagerId]] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses != null) { - statuses.synchronized { + val shuffleStatus = shuffleStatuses.get(shuffleId).orNull + if (shuffleStatus != null) { + shuffleStatus.withMapStatuses { statuses => if (statuses.nonEmpty) { // HashMap to add up sizes of all blocks at the same location val locs = new HashMap[BlockManagerId, Long] @@ -502,77 +566,24 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } } - private def removeBroadcast(bcast: Broadcast[_]): Unit = { - if (null != bcast) { - broadcastManager.unbroadcast(bcast.id, - removeFromDriver = true, blocking = false) + /** Called to get current epoch number. */ + def getEpoch: Long = { + epochLock.synchronized { + return epoch } } - private def clearCachedBroadcast(): Unit = { - for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2) - cachedSerializedBroadcast.clear() - } - - def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { - var statuses: Array[MapStatus] = null - var retBytes: Array[Byte] = null - var epochGotten: Long = -1 - - // Check to see if we have a cached version, returns true if it does - // and has side effect of setting retBytes. 
If not returns false - // with side effect of setting statuses - def checkCachedStatuses(): Boolean = { - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - clearCachedBroadcast() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - retBytes = bytes - true - case None => - logDebug("cached status not found for : " + shuffleId) - statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus]) - epochGotten = epoch - false - } - } - } - - if (checkCachedStatuses()) return retBytes - var shuffleIdLock = shuffleIdLocks.get(shuffleId) - if (null == shuffleIdLock) { - val newLock = new Object() - // in general, this condition should be false - but good to be paranoid - val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) - shuffleIdLock = if (null != prevLock) prevLock else newLock - } - // synchronize so we only serialize/broadcast it once since multiple threads call - // in parallel - shuffleIdLock.synchronized { - // double check to make sure someone else didn't serialize and cache the same - // mapstatus while we were waiting on the synchronize - if (checkCachedStatuses()) return retBytes - - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "statuses"; let's serialize and return that - val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager, - isLocal, minSizeForBroadcast) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes - if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast - } else { - logInfo("Epoch changed, not caching!") - removeBroadcast(bcast) + // This method is only called in local-mode. + def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + shuffleStatuses.get(shuffleId) match { + case Some (shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } - } - bytes + case None => + Seq.empty } } @@ -580,21 +591,121 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, mapOutputRequests.offer(PoisonPill) threadpool.shutdown() sendTracker(StopMapOutputTracker) - mapStatuses.clear() trackerEndpoint = null - cachedSerializedStatuses.clear() - clearCachedBroadcast() - shuffleIdLocks.clear() + shuffleStatuses.clear() } } /** - * MapOutputTracker for the executors, which fetches map output information from the driver's - * MapOutputTrackerMaster. + * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster. + * Note that this is not used in local-mode; instead, local-mode Executors access the + * MapOutputTrackerMaster directly (which is possible because the master and worker share a comon + * superclass). 
*/ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { - protected val mapStatuses: Map[Int, Array[MapStatus]] = + + val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + + /** Remembers which map output locations are currently being fetched on an executor. */ + private val fetching = new HashSet[Int] + + override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + try { + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } catch { + case e: MetadataFetchFailedException => + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + mapStatuses.clear() + throw e + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis + var fetchedStatuses: Array[MapStatus] = null + fetching.synchronized { + // Someone else is fetching it; wait for them to be done + while (fetching.contains(shuffleId)) { + try { + fetching.wait() + } catch { + case e: InterruptedException => + } + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. + fetching += shuffleId + } + } + + if (fetchedStatuses == null) { + // We won the race to fetch the statuses; do so + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + // This try-finally prevents hangs due to timeouts: + try { + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${System.currentTimeMillis - startTime} ms") + + if (fetchedStatuses != null) { + fetchedStatuses + } else { + logError("Missing all output locations for shuffle " + shuffleId) + throw new MetadataFetchFailedException( + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) + } + } else { + statuses + } + } + + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int): Unit = { + mapStatuses.remove(shuffleId) + } + + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each executor task calls this with the latest epoch + * number on the driver at the time it was created. 
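getStatuses() above deduplicates concurrent fetches for the same shuffle with a `fetching` set and wait()/notifyAll(): only the thread that wins the race performs the RPC, everyone else blocks and then reads the cache. A compact standalone sketch of the same pattern, with fetchFromDriver standing in for the askTracker RPC and all names invented for the example:

import scala.collection.mutable

// Sketch of fetch deduplication: many threads may ask for the same key, but only
// one performs the expensive fetch; the others wait and then read the cache.
final class DedupedFetcher[K, V](fetchFromDriver: K => V) {
  private val cache = mutable.Map.empty[K, V]
  private val fetching = mutable.Set.empty[K]

  def get(key: K): V = {
    val mustFetch = fetching.synchronized {
      // Wait while someone else is fetching this key.
      while (fetching.contains(key)) fetching.wait()
      if (cache.contains(key)) false
      else { fetching += key; true }    // we won the race; we will fetch
    }
    if (!mustFetch) {
      fetching.synchronized(cache(key))
    } else {
      try {
        val v = fetchFromDriver(key)    // expensive call, done outside the lock
        fetching.synchronized { cache(key) = v }
        v
      } finally {
        fetching.synchronized {
          fetching -= key
          fetching.notifyAll()          // wake up anyone waiting on this key
        }
      }
    }
  }
}

Re-checking the cache after wait() returns covers both spurious wakeups and the case where the winning fetch failed and must be retried by a waiter.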
+ */ + def updateEpoch(newEpoch: Long): Unit = { + epochLock.synchronized { + if (newEpoch > epoch) { + logInfo("Updating epoch to " + newEpoch + " and clearing cache") + epoch = newEpoch + mapStatuses.clear() + } + } + } } private[spark] object MapOutputTracker extends Logging { @@ -683,7 +794,7 @@ private[spark] object MapOutputTracker extends Logging { * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - private def convertMapStatuses( + def convertMapStatuses( shuffleId: Int, startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 53384e7373252..f10a41286c52f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -367,7 +367,7 @@ private[deploy] class Master( drivers.find(_.id == driverId).foreach { driver => driver.worker = Some(worker) driver.state = DriverState.RUNNING - worker.drivers(driverId) = driver + worker.addDriver(driver) } } case None => @@ -547,6 +547,9 @@ private[deploy] class Master( workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication) + // Update the state of recovered apps to RUNNING + apps.filter(_.state == ApplicationState.WAITING).foreach(_.state = ApplicationState.RUNNING) + // Reschedule drivers which were not claimed by any workers drivers.filter(_.worker.isEmpty).foreach { d => logWarning(s"Driver ${d.id} was not found after master recovery") @@ -796,9 +799,19 @@ private[deploy] class Master( } private def relaunchDriver(driver: DriverInfo) { - driver.worker = None - driver.state = DriverState.RELAUNCHING - waitingDrivers += driver + // We must setup a new driver with a new driver id here, because the original driver may + // be still running. Consider this scenario: a worker is network partitioned with master, + // the master then relaunches driver driverID1 with a driver id driverID2, then the worker + // reconnects to master. From this point on, if driverID2 is equal to driverID1, then master + // can not distinguish the statusUpdate of the original driver and the newly relaunched one, + // for example, when DriverStateChanged(driverID1, KILLED) arrives at master, master will + // remove driverID1, so the newly relaunched driver disappears too. See SPARK-19900 for details. + removeDriver(driver.id, DriverState.RELAUNCHING, None) + val newDriver = createDriver(driver.desc) + persistenceEngine.addDriver(newDriver) + drivers.add(newDriver) + waitingDrivers += newDriver + schedule() } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala similarity index 88% rename from resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala rename to core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index 5adeb8e605ff4..35621daf9c0d7 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.deploy.yarn.security +package org.apache.spark.deploy.security import scala.reflect.runtime.universe import scala.util.control.NonFatal @@ -24,17 +24,16 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.token.{Token, TokenIdentifier} -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[security] class HBaseCredentialProvider extends ServiceCredentialProvider with Logging { +private[security] class HBaseDelegationTokenProvider + extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hbase" - override def obtainCredentials( + override def obtainDelegationTokens( hadoopConf: Configuration, - sparkConf: SparkConf, creds: Credentials): Option[Long] = { try { val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) @@ -55,7 +54,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide None } - override def credentialsRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { hbaseConf(hadoopConf).get("hbase.security.authentication") == "kerberos" } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala new file mode 100644 index 0000000000000..89b6f52ba4bca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging + +/** + * Manages all the registered HadoopDelegationTokenProviders and offer APIs for other modules to + * obtain delegation tokens and their renewal time. By default [[HadoopFSDelegationTokenProvider]], + * [[HiveDelegationTokenProvider]] and [[HBaseDelegationTokenProvider]] will be loaded in if not + * explicitly disabled. + * + * Also, each HadoopDelegationTokenProvider is controlled by + * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to + * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be + * enabled/disabled by the configuration spark.security.credentials.hive.enabled. 
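The per-service enable flag described above resolves in a specific order: the new spark.security.credentials.{service}.enabled key wins if present; otherwise the deprecated spark.yarn.security.* keys (handled in isServiceEnabled below) are consulted; and the service defaults to enabled. A small sketch of that resolution over a plain Map of config values — the key patterns are taken from the patch, but the helper itself is illustrative:

// Sketch of the enabled-flag resolution for a delegation token provider.
object ProviderEnabledDemo extends App {
  private val deprecatedPatterns = List(
    "spark.yarn.security.tokens.%s.enabled",
    "spark.yarn.security.credentials.%s.enabled")
  private val currentPattern = "spark.security.credentials.%s.enabled"

  def isServiceEnabled(conf: Map[String, String], service: String): Boolean = {
    // A service is enabled unless some key explicitly says "false"; the
    // deprecated keys only matter when the new key is absent.
    val enabledPerDeprecatedKeys = deprecatedPatterns.forall { pattern =>
      conf.get(pattern.format(service)).map(_.toBoolean).getOrElse(true)
    }
    conf.get(currentPattern.format(service))
      .map(_.toBoolean)
      .getOrElse(enabledPerDeprecatedKeys)
  }

  val conf = Map("spark.yarn.security.credentials.hive.enabled" -> "false")
  println(isServiceEnabled(conf, "hive"))     // false: deprecated key honoured
  println(isServiceEnabled(conf, "hbase"))    // true: enabled by default
}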
+ * + * @param sparkConf Spark configuration + * @param hadoopConf Hadoop configuration + * @param fileSystems Delegation tokens will be fetched for these Hadoop filesystems. + */ +private[spark] class HadoopDelegationTokenManager( + sparkConf: SparkConf, + hadoopConf: Configuration, + fileSystems: Set[FileSystem]) + extends Logging { + + private val deprecatedProviderEnabledConfigs = List( + "spark.yarn.security.tokens.%s.enabled", + "spark.yarn.security.credentials.%s.enabled") + private val providerEnabledConfig = "spark.security.credentials.%s.enabled" + + // Maintain all the registered delegation token providers + private val delegationTokenProviders = getDelegationTokenProviders + logDebug(s"Using the following delegation token providers: " + + s"${delegationTokenProviders.keys.mkString(", ")}.") + + private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { + val providers = List(new HadoopFSDelegationTokenProvider(fileSystems), + new HiveDelegationTokenProvider, + new HBaseDelegationTokenProvider) + + // Filter out providers for which spark.security.credentials.{service}.enabled is false. + providers + .filter { p => isServiceEnabled(p.serviceName) } + .map { p => (p.serviceName, p) } + .toMap + } + + def isServiceEnabled(serviceName: String): Boolean = { + val key = providerEnabledConfig.format(serviceName) + + deprecatedProviderEnabledConfigs.foreach { pattern => + val deprecatedKey = pattern.format(serviceName) + if (sparkConf.contains(deprecatedKey)) { + logWarning(s"${deprecatedKey} is deprecated. Please use ${key} instead.") + } + } + + val isEnabledDeprecated = deprecatedProviderEnabledConfigs.forall { pattern => + sparkConf + .getOption(pattern.format(serviceName)) + .map(_.toBoolean) + .getOrElse(true) + } + + sparkConf + .getOption(key) + .map(_.toBoolean) + .getOrElse(isEnabledDeprecated) + } + + /** + * Get delegation token provider for the specified service. + */ + def getServiceDelegationTokenProvider(service: String): Option[HadoopDelegationTokenProvider] = { + delegationTokenProviders.get(service) + } + + /** + * Writes delegation tokens to creds. Delegation tokens are fetched from all registered + * providers. + * + * @return Time after which the fetched delegation tokens should be renewed. + */ + def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Long = { + delegationTokenProviders.values.flatMap { provider => + if (provider.delegationTokensRequired(hadoopConf)) { + provider.obtainDelegationTokens(hadoopConf, creds) + } else { + logDebug(s"Service ${provider.serviceName} does not require a token." + + s" Check your configuration to see if security is disabled or not.") + None + } + }.foldLeft(Long.MaxValue)(math.min) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala new file mode 100644 index 0000000000000..f162e7e58c53a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.Credentials + +/** + * Hadoop delegation token provider. + */ +private[spark] trait HadoopDelegationTokenProvider { + + /** + * Name of the service to provide delegation tokens. This name should be unique. Spark will + * internally use this name to differentiate delegation token providers. + */ + def serviceName: String + + /** + * Returns true if delegation tokens are required for this service. By default, it is based on + * whether Hadoop security is enabled. + */ + def delegationTokensRequired(hadoopConf: Configuration): Boolean + + /** + * Obtain delegation tokens for this service and get the time of the next renewal. + * @param hadoopConf Configuration of current Hadoop Compatible system. + * @param creds Credentials to add tokens and security keys to. + * @return If the returned tokens are renewable and can be renewed, return the time of the next + * renewal, otherwise None should be returned. + */ + def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Option[Long] +} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala new file mode 100644 index 0000000000000..13157f33e2bf9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import scala.collection.JavaConverters._ +import scala.util.Try + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.mapred.Master +import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging + +private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Set[FileSystem]) + extends HadoopDelegationTokenProvider with Logging { + + // This tokenRenewalInterval will be set in the first call to obtainDelegationTokens. 
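The provider trait shown above is deliberately small: a service name, a "do I need to run" predicate, and a token-obtaining call that optionally reports when renewal is next due. The sketch below composes two dummy providers the way the manager's obtainDelegationTokens() does, folding the earliest renewal time out with math.min; the simplified interface and provider bodies are invented for the example and carry no Hadoop types.

// Simplified stand-in for the delegation token provider interface: no Hadoop
// Credentials, just an optional "next renewal" timestamp.
trait TokenProviderSketch {
  def serviceName: String
  def tokensRequired: Boolean
  def obtainTokens(): Option[Long]   // time of next renewal, if renewable
}

object TokenManagerSketch extends App {
  val providers: List[TokenProviderSketch] = List(
    new TokenProviderSketch {
      val serviceName = "hadoopfs"
      val tokensRequired = true
      def obtainTokens(): Option[Long] = Some(1000L)
    },
    new TokenProviderSketch {
      val serviceName = "hbase"
      val tokensRequired = false          // e.g. security disabled
      def obtainTokens(): Option[Long] = None
    }
  )

  // Ask every required provider for tokens and keep the earliest renewal time.
  val nextRenewal: Long = providers.flatMap { p =>
    if (p.tokensRequired) p.obtainTokens() else None
  }.foldLeft(Long.MaxValue)(math.min)

  println(nextRenewal)   // 1000
}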
+ // If None, no token renewer is specified or no token can be renewed, + // so we cannot get the token renewal interval. + private var tokenRenewalInterval: Option[Long] = null + + override val serviceName: String = "hadoopfs" + + override def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Option[Long] = { + + val newCreds = fetchDelegationTokens( + getTokenRenewer(hadoopConf), + fileSystems) + + // Get the token renewal interval if it is not set. It will only be called once. + if (tokenRenewalInterval == null) { + tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, fileSystems) + } + + // Get the time of next renewal. + val nextRenewalDate = tokenRenewalInterval.flatMap { interval => + val nextRenewalDates = newCreds.getAllTokens.asScala + .filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]) + .map { token => + val identifier = token + .decodeIdentifier() + .asInstanceOf[AbstractDelegationTokenIdentifier] + identifier.getIssueDate + interval + } + if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min) + } + + creds.addAll(newCreds) + nextRenewalDate + } + + def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + UserGroupInformation.isSecurityEnabled + } + + private def getTokenRenewer(hadoopConf: Configuration): String = { + val tokenRenewer = Master.getMasterPrincipal(hadoopConf) + logDebug("Delegation token renewer is: " + tokenRenewer) + + if (tokenRenewer == null || tokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer." + logError(errorMessage) + throw new SparkException(errorMessage) + } + + tokenRenewer + } + + private def fetchDelegationTokens( + renewer: String, + filesystems: Set[FileSystem]): Credentials = { + + val creds = new Credentials() + + filesystems.foreach { fs => + logInfo("getting token for: " + fs) + fs.addDelegationTokens(renewer, creds) + } + + creds + } + + private def getTokenRenewalInterval( + hadoopConf: Configuration, + filesystems: Set[FileSystem]): Option[Long] = { + // We cannot use the tokens generated with renewer yarn. Trying to renew + // those will fail with an access control issue. So create new tokens with the logged in + // user as renewer. 
+ val creds = fetchDelegationTokens( + UserGroupInformation.getCurrentUser.getUserName, + filesystems) + + val renewIntervals = creds.getAllTokens.asScala.filter { + _.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier] + }.flatMap { token => + Try { + val newExpiration = token.renew(hadoopConf) + val identifier = token.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] + val interval = newExpiration - identifier.getIssueDate + logInfo(s"Renewal interval is $interval for token ${token.getKind.toString}") + interval + }.toOption + } + if (renewIntervals.isEmpty) None else Some(renewIntervals.min) + } +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala similarity index 54% rename from resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala rename to core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index 16d8fc32bb42d..53b9f898c6e7d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -15,97 +15,89 @@ * limitations under the License. */ -package org.apache.spark.deploy.yarn.security +package org.apache.spark.deploy.security import java.lang.reflect.UndeclaredThrowableException import java.security.PrivilegedExceptionAction -import scala.reflect.runtime.universe import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.Token -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[security] class HiveCredentialProvider extends ServiceCredentialProvider with Logging { +private[security] class HiveDelegationTokenProvider + extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hive" + private val classNotFoundErrorStr = s"You are attempting to use the " + + s"${getClass.getCanonicalName}, but your Spark distribution is not built with Hive libraries." + private def hiveConf(hadoopConf: Configuration): Configuration = { try { - val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) - // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down - // to a Configuration and used without reflection - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - // using the (Configuration, Class) constructor allows the current configuration to be - // included in the hive config. 
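The arithmetic in the HadoopFSDelegationTokenProvider helpers above is simple but easy to misread: the renewal interval is measured once as (expiration after a renew) minus (issue date), and every later batch of tokens schedules its next renewal at (issue date) + (interval), keeping the minimum across tokens. A plain-number restatement, with all values being hypothetical epoch milliseconds and the case class invented for the example:

// Sketch of the renewal-time arithmetic, with tokens reduced to
// (issueDate, expirationAfterRenew) pairs in epoch milliseconds.
object RenewalMathDemo extends App {
  final case class TokenInfo(issueDate: Long, expirationAfterRenew: Long)

  // Measured once, from freshly fetched tokens renewed by the current user.
  def renewalInterval(tokens: Seq[TokenInfo]): Option[Long] = {
    val intervals = tokens.map(t => t.expirationAfterRenew - t.issueDate)
    if (intervals.isEmpty) None else Some(intervals.min)
  }

  // Applied to later batches: next renewal = issue date + interval, min over tokens.
  def nextRenewalDate(issueDates: Seq[Long], interval: Long): Option[Long] = {
    val dates = issueDates.map(_ + interval)
    if (dates.isEmpty) None else Some(dates.min)
  }

  val interval = renewalInterval(Seq(TokenInfo(0L, 86400000L))).get   // 24h
  println(nextRenewalDate(Seq(1000L, 2000L), interval))               // Some(86401000)
}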
- val ctor = hiveConfClass.getDeclaredConstructor(classOf[Configuration], - classOf[Object].getClass) - ctor.newInstance(hadoopConf, hiveConfClass).asInstanceOf[Configuration] + new HiveConf(hadoopConf, classOf[HiveConf]) } catch { case NonFatal(e) => logDebug("Fail to create Hive Configuration", e) hadoopConf + case e: NoClassDefFoundError => + logWarning(classNotFoundErrorStr) + hadoopConf } } - override def credentialsRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { UserGroupInformation.isSecurityEnabled && hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty } - override def obtainCredentials( + override def obtainDelegationTokens( hadoopConf: Configuration, - sparkConf: SparkConf, creds: Credentials): Option[Long] = { - val conf = hiveConf(hadoopConf) - - val principalKey = "hive.metastore.kerberos.principal" - val principal = conf.getTrimmed(principalKey, "") - require(principal.nonEmpty, s"Hive principal $principalKey undefined") - val metastoreUri = conf.getTrimmed("hive.metastore.uris", "") - require(metastoreUri.nonEmpty, "Hive metastore uri undefined") - - val currentUser = UserGroupInformation.getCurrentUser() - logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + - s"$principal at $metastoreUri") + try { + val conf = hiveConf(hadoopConf) - val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) - val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - val closeCurrent = hiveClass.getMethod("closeCurrent") + val principalKey = "hive.metastore.kerberos.principal" + val principal = conf.getTrimmed(principalKey, "") + require(principal.nonEmpty, s"Hive principal $principalKey undefined") + val metastoreUri = conf.getTrimmed("hive.metastore.uris", "") + require(metastoreUri.nonEmpty, "Hive metastore uri undefined") - try { - // get all the instance methods before invoking any - val getDelegationToken = hiveClass.getMethod("getDelegationToken", - classOf[String], classOf[String]) - val getHive = hiveClass.getMethod("get", hiveConfClass) + val currentUser = UserGroupInformation.getCurrentUser() + logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + + s"$principal at $metastoreUri") doAsRealUser { - val hive = getHive.invoke(null, conf) - val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal) - .asInstanceOf[String] + val hive = Hive.get(conf, classOf[HiveConf]) + val tokenStr = hive.getDelegationToken(currentUser.getUserName(), principal) + val hive2Token = new Token[DelegationTokenIdentifier]() hive2Token.decodeFromUrlString(tokenStr) logInfo(s"Get Token from hive metastore: ${hive2Token.toString}") creds.addToken(new Text("hive.server2.delegation.token"), hive2Token) } + + None } catch { case NonFatal(e) => - logDebug(s"Fail to get token from service $serviceName", e) + logDebug(s"Failed to get token from service $serviceName", e) + None + case e: NoClassDefFoundError => + logWarning(classNotFoundErrorStr) + None } finally { Utils.tryLogNonFatalError { - closeCurrent.invoke(null) + Hive.closeCurrent() } } - - None } /** diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5b396687dd11a..19e7eb086f413 100644 --- 
a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -322,8 +322,14 @@ private[spark] class Executor( throw new TaskKilledException(killReason.get) } - logDebug("Task " + taskId + "'s epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) + // The purpose of updating the epoch here is to invalidate executor map output status cache + // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be + // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so + // we don't need to make any special calls here. + if (!isLocal) { + logDebug("Task " + taskId + "'s epoch is " + task.epoch) + env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch) + } // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index 515c9c24a9e2f..8f4c1b60920db 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -28,7 +28,7 @@ private object ConfigHelpers { def toNumber[T](s: String, converter: String => T, key: String, configType: String): T = { try { - converter(s) + converter(s.trim) } catch { case _: NumberFormatException => throw new IllegalArgumentException(s"$key should be $configType, but was $s") @@ -37,7 +37,7 @@ private object ConfigHelpers { def toBoolean(s: String, key: String): Boolean = { try { - s.toBoolean + s.trim.toBoolean } catch { case _: IllegalArgumentException => throw new IllegalArgumentException(s"$key should be boolean, but was $s") diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7827e6760f355..84ef57f2d271b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -151,6 +151,14 @@ package object config { .createOptional // End blacklist confs + private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE = + ConfigBuilder("spark.files.fetchFailure.unRegisterOutputOnHost") + .doc("Whether to un-register all the outputs on the host when a FetchFailure is received. " + + "This defaults to false, which means we only un-register the outputs related to the " + + "exact executor (instead of the whole host) on a FetchFailure.") + .booleanConf + .createWithDefault(false) + private[spark] val LISTENER_BUS_EVENT_QUEUE_CAPACITY = ConfigBuilder("spark.scheduler.listenerbus.eventqueue.capacity") .withAlternative("spark.scheduler.listenerbus.eventqueue.size") diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala index 5a5bd7fbbe2f8..cbee136871012 100644 --- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.{PascalDistribution, PoissonDistribution} +import org.apache.commons.math3.distribution.PoissonDistribution /** * An ApproximateEvaluator for counts.
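The `spark.files.fetchFailure.unRegisterOutputOnHost` entry added to `config/package.scala` above is only acted on together with the external shuffle service (see the `DAGScheduler` change further down). A minimal sketch of how an application might opt in — the application name is a placeholder and not part of this patch:

```scala
import org.apache.spark.SparkConf

// Illustrative only: both settings would be passed to a cluster application's SparkContext
// (or via spark-submit --conf); the second key is the one introduced in the hunk above.
val conf = new SparkConf()
  .setAppName("fetch-failure-demo") // placeholder name
  .set("spark.shuffle.service.enabled", "true")
  .set("spark.files.fetchFailure.unRegisterOutputOnHost", "true")
```

The `CountEvaluator` hunk that follows keeps only the Poisson form of the confidence interval. As a rough illustration of the arithmetic (assuming `p` is the fraction of partitions merged so far), the numbers below line up with the `merge(1, 1)` expectation in the updated `CountEvaluatorSuite` later in this diff:

```scala
import org.apache.commons.math3.distribution.PoissonDistribution

// Sketch only: 1 of 10 partitions seen (p = 0.1) with sum = 1 element counted so far.
val sum = 1L
val p = 0.1
val confidence = 0.95
val dist = new PoissonDistribution(sum * (1 - p) / p) // ~ Poisson(9)
val low = sum + dist.inverseCumulativeProbability((1 - confidence) / 2)  // 1 + 4  = 5
val high = sum + dist.inverseCumulativeProbability((1 + confidence) / 2) // 1 + 15 = 16
// The suite expects BoundedDouble(10.0, 0.95, 5.0, 16.0): mean 1 + 9 = 10, bounds 5 and 16.
```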
@@ -48,22 +48,11 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) private[partial] object CountEvaluator { def bound(confidence: Double, sum: Long, p: Double): BoundedDouble = { - // Let the total count be N. A fraction p has been counted already, with sum 'sum', - // as if each element from the total data set had been seen with probability p. - val dist = - if (sum <= 10000) { - // The remaining count, k=N-sum, may be modeled as negative binomial (aka Pascal), - // where there have been 'sum' successes of probability p already. (There are several - // conventions, but this is the one followed by Commons Math3.) - new PascalDistribution(sum.toInt, p) - } else { - // For large 'sum' (certainly, > Int.MaxValue!), use a Poisson approximation, which has - // a different interpretation. "sum" elements have been observed having scanned a fraction - // p of the data. This suggests data is counted at a rate of sum / p across the whole data - // set. The total expected count from the rest is distributed as - // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p) - new PoissonDistribution(sum * (1 - p) / p) - } + // "sum" elements have been observed having scanned a fraction + // p of the data. This suggests data is counted at a rate of sum / p across the whole data + // set. The total expected count from the rest is distributed as + // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p) + val dist = new PoissonDistribution(sum * (1 - p) / p) // Not quite symmetric; calculate interval straight from discrete distribution val low = dist.inverseCumulativeProbability((1 - confidence) / 2) val high = dist.inverseCumulativeProbability((1 + confidence) / 2) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ab2255f8a6654..fafe9cafdc18f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -35,6 +35,7 @@ import org.apache.commons.lang3.SerializationUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} @@ -187,6 +188,14 @@ class DAGScheduler( /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + /** + * Whether to unregister all the outputs on the host when a FetchFailure is received. This + * defaults to false, which means we only unregister the outputs related to the exact + * executor (instead of the whole host) on a FetchFailure. + */ + private[scheduler] val unRegisterOutputOnHostOnFetchFailure = + sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) + /** * Number of consecutive stage attempts allowed before a stage is aborted.
*/ @@ -328,25 +337,14 @@ class DAGScheduler( val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() - val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep) + val stage = new ShuffleMapStage( + id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker) stageIdToStage(id) = stage shuffleIdToMapStage(shuffleDep.shuffleId) = stage updateJobIdStageIdMaps(jobId, stage) - if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { - // A previously run stage generated partitions for this shuffle, so for each output - // that's still available, copy information about that output location to the new stage - // (so we don't unnecessarily re-compute that data). - val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) - val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - (0 until locs.length).foreach { i => - if (locs(i) ne null) { - // locs(i) will be null if missing - stage.addOutputLoc(i, locs(i)) - } - } - } else { + if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") @@ -1217,7 +1215,8 @@ class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. - shuffleStage.addOutputLoc(smt.partitionId, status) + mapOutputTracker.registerMapOutput( + shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) // Remove the task's partition from pending partitions. This may have already been // done above, but will not have been done yet in cases where the task attempt was // from an earlier attempt of the stage (i.e., not the attempt that's currently @@ -1234,16 +1233,14 @@ class DAGScheduler( logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) + // This call to increment the epoch may not be strictly necessary, but it is retained + // for now in order to minimize the changes in behavior from an earlier version of the + // code. This existing behavior of always incrementing the epoch following any + // successful shuffle map stage completion may have benefits by causing unneeded + // cached map outputs to be cleaned up earlier on executors. In the future we can + // consider removing this call, but this will require some extra investigation. + // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. 
+ mapOutputTracker.incrementEpoch() clearCacheLocs() @@ -1343,13 +1340,26 @@ class DAGScheduler( } // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) } // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, filesLost = true, Some(task.epoch)) + val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled && + unRegisterOutputOnHostOnFetchFailure) { + // We had a fetch failure with the external shuffle service, so we + // assume all shuffle data on the node is bad. + Some(bmAddress.host) + } else { + // Unregister shuffle data just for one executor (we don't have any + // reason to believe shuffle data has been lost for the entire host). + None + } + removeExecutorAndUnregisterOutputs( + execId = bmAddress.executorId, + fileLost = true, + hostToUnregisterOutputs = hostToUnregisterOutputs, + maybeEpoch = Some(task.epoch)) } } @@ -1383,32 +1393,42 @@ class DAGScheduler( */ private[scheduler] def handleExecutorLost( execId: String, - filesLost: Boolean, - maybeEpoch: Option[Long] = None) { + workerLost: Boolean): Unit = { + // if the cluster manager explicitly tells us that the entire worker was lost, then + // we know to unregister shuffle output. (Note that "worker" specifically refers to the process + // from a Standalone cluster, where the shuffle service lives in the Worker.) + val fileLost = workerLost || !env.blockManager.externalShuffleServiceEnabled + removeExecutorAndUnregisterOutputs( + execId = execId, + fileLost = fileLost, + hostToUnregisterOutputs = None, + maybeEpoch = None) + } + + private def removeExecutorAndUnregisterOutputs( + execId: String, + fileLost: Boolean, + hostToUnregisterOutputs: Option[String], + maybeEpoch: Option[Long] = None): Unit = { val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { failedEpoch(execId) = currentEpoch logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) blockManagerMaster.removeExecutor(execId) - - if (filesLost || !env.blockManager.externalShuffleServiceEnabled) { - logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleIdToMapStage) { - stage.removeOutputsOnExecutor(execId) - mapOutputTracker.registerMapOutputs( - shuffleId, - stage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) - } - if (shuffleIdToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() + if (fileLost) { + hostToUnregisterOutputs match { + case Some(host) => + logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch)) + mapOutputTracker.removeOutputsOnHost(host) + case None => + logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) + mapOutputTracker.removeOutputsOnExecutor(execId) } clearCacheLocs() + + } else { + logDebug("Additional executor lost message for %s (epoch %d)".format(execId, currentEpoch)) } - } else { - logDebug("Additional executor lost message for " + execId + - "(epoch " + currentEpoch + ")") } } @@ -1701,11 +1721,11 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.handleExecutorAdded(execId, host) case ExecutorLost(execId, reason) => 
- val filesLost = reason match { + val workerLost = reason match { case SlaveLost(_, true) => true case _ => false } - dagScheduler.handleExecutorLost(execId, filesLost) + dagScheduler.handleExecutorLost(execId, workerLost) case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index db4d9efa2270c..05f650fbf5df9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -19,9 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet -import org.apache.spark.ShuffleDependency +import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage( parents: List[Stage], firstJobId: Int, callSite: CallSite, - val shuffleDep: ShuffleDependency[_, _, _]) + val shuffleDep: ShuffleDependency[_, _, _], + mapOutputTrackerMaster: MapOutputTrackerMaster) extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { private[this] var _mapStageJobs: List[ActiveJob] = Nil - private[this] var _numAvailableOutputs: Int = 0 - /** * Partitions that either haven't yet been computed, or that were computed on an executor * that has since been lost, so should be re-computed. This variable is used by the @@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage( */ val pendingPartitions = new HashSet[Int] - /** - * List of [[MapStatus]] for each partition. The index of the array is the map partition id, - * and each value in the array is the list of possible [[MapStatus]] for a partition - * (a single task might run multiple times). - */ - private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - override def toString: String = "ShuffleMapStage " + id /** @@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage( /** * Number of partitions that have shuffle outputs. * When this reaches [[numPartitions]], this map stage is ready. - * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. */ - def numAvailableOutputs: Int = _numAvailableOutputs + def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId) /** * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. - * This should be the same as `outputLocs.contains(Nil)`. */ - def isAvailable: Boolean = _numAvailableOutputs == numPartitions + def isAvailable: Boolean = numAvailableOutputs == numPartitions /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). 
*/ override def findMissingPartitions(): Seq[Int] = { - val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) - assert(missing.size == numPartitions - _numAvailableOutputs, - s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") - missing - } - - def addOutputLoc(partition: Int, status: MapStatus): Unit = { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - _numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - _numAvailableOutputs -= 1 - } - } - - /** - * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned - * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, - * that position is filled with null. - */ - def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = { - outputLocs.map(_.headOption.orNull) - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. - */ - def removeOutputsOnExecutor(execId: String): Unit = { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - _numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, _numAvailableOutputs, numPartitions, isAvailable)) - } + mapOutputTrackerMaster + .findMissingPartitions(shuffleDep.shuffleId) + .getOrElse(0 until numPartitions) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index f3033e28b47d0..629cfc7c7a8ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( var backend: SchedulerBackend = null - val mapOutputTracker = SparkEnv.get.mapOutputTracker + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b7cbed468517c..d63381c78bc3b 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -82,7 +82,7 @@ private[ui] class ExecutorsPage( ++ -
++ + ++ ++ ++ diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 4fe5c5e4fee4a..bc3d23e3fbb29 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -139,21 +139,21 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) + // This is expected to fail because no outputs have been registered for the shuffle. intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) + val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - masterTracker.incrementEpoch() + assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 622f7985ba444..3931d53b4ae0a 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -359,6 +359,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val shuffleMapRdd = new MyRDD(sc, 1, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) + mapTrackerMaster.registerShuffle(0, 1) // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, @@ -393,7 +394,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => - mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) + mapTrackerMaster.registerMapOutput(0, 0, mapStatus) } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index de719990cf47a..b089357e7b868 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -505,8 +505,8 @@ class SparkSubmitSuite assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val rScriptDir = - Seq(sparkHome, "R", "pkg", "inst", "tests", "packageInAJarTest.R").mkString(File.separator) + val rScriptDir = Seq( + sparkHome, "R", "pkg", "tests", "fulltests", 
"packageInAJarTest.R").mkString(File.separator) assert(new File(rScriptDir).exists) IvyTestUtils.withRepository(main, None, None, withR = true) { repo => val args = Seq( @@ -527,7 +527,7 @@ class SparkSubmitSuite // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val rScriptDir = - Seq(sparkHome, "R", "pkg", "inst", "tests", "testthat", "jarTest.R").mkString(File.separator) + Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator) assert(new File(rScriptDir).exists) // compile a small jar containing a class that will be called from R code. diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 539264652d7d5..6bb0eec040787 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -19,14 +19,19 @@ package org.apache.spark.deploy.master import java.util.Date import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.{HashMap, HashSet} import scala.concurrent.duration._ import scala.io.Source import scala.language.postfixOps +import scala.reflect.ClassTag import org.json4s._ import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} @@ -34,7 +39,51 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ import org.apache.spark.deploy.DeployMessages._ -import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.serializer + +object MockWorker { + val counter = new AtomicInteger(10000) +} + +class MockWorker(master: RpcEndpointRef, conf: SparkConf = new SparkConf) extends RpcEndpoint { + val seq = MockWorker.counter.incrementAndGet() + val id = seq.toString + override val rpcEnv: RpcEnv = RpcEnv.create("worker", "localhost", seq, + conf, new SecurityManager(conf)) + var apps = new mutable.HashMap[String, String]() + val driverIdToAppId = new mutable.HashMap[String, String]() + def newDriver(driverId: String): RpcEndpointRef = { + val name = s"driver_${drivers.size}" + rpcEnv.setupEndpoint(name, new RpcEndpoint { + override val rpcEnv: RpcEnv = MockWorker.this.rpcEnv + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId, _) => + apps(appId) = appId + driverIdToAppId(driverId) = appId + } + }) + } + + val appDesc = DeployTestUtils.createAppDesc() + val drivers = mutable.HashSet[String]() + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, _, _) => + masterRef.send(WorkerLatestState(id, Nil, drivers.toSeq)) + case LaunchDriver(driverId, desc) => + drivers += driverId + master.send(RegisterApplication(appDesc, newDriver(driverId))) + case KillDriver(driverId) => + master.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + drivers -= driverId + driverIdToAppId.get(driverId) match { + case Some(appId) => + apps.remove(appId) + 
master.send(UnregisterApplication(appId)) + } + driverIdToAppId.remove(driverId) + } +} class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter { @@ -134,6 +183,81 @@ class MasterSuite extends SparkFunSuite CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts } + test("master correctly recovers the application") { + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.deploy.recoveryMode", "CUSTOM") + conf.set("spark.deploy.recoveryMode.factory", + classOf[FakeRecoveryModeFactory].getCanonicalName) + conf.set("spark.master.rest.enabled", "false") + + val fakeAppInfo = makeAppInfo(1024) + val fakeWorkerInfo = makeWorkerInfo(8192, 16) + val fakeDriverInfo = new DriverInfo( + startTime = 0, + id = "test_driver", + desc = new DriverDescription( + jarUrl = "", + mem = 1024, + cores = 1, + supervise = false, + command = new Command("", Nil, Map.empty, Nil, Nil, Nil)), + submitDate = new Date()) + + // Build the fake recovery data + FakeRecoveryModeFactory.persistentData.put(s"app_${fakeAppInfo.id}", fakeAppInfo) + FakeRecoveryModeFactory.persistentData.put(s"driver_${fakeDriverInfo.id}", fakeDriverInfo) + FakeRecoveryModeFactory.persistentData.put(s"worker_${fakeWorkerInfo.id}", fakeWorkerInfo) + + var master: Master = null + try { + master = makeMaster(conf) + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + // Wait until the Master has recovered from the checkpoint data. + eventually(timeout(5 seconds), interval(100 milliseconds)) { + master.idToApp.size should be(1) + } + + master.idToApp.keySet should be(Set(fakeAppInfo.id)) + getDrivers(master) should be(Set(fakeDriverInfo)) + master.workers should be(Set(fakeWorkerInfo)) + + // Notify the Master about the executor and driver info so that it is correctly recovered. + val fakeExecutors = List( + new ExecutorDescription(fakeAppInfo.id, 0, 8, ExecutorState.RUNNING), + new ExecutorDescription(fakeAppInfo.id, 0, 7, ExecutorState.RUNNING)) + + fakeAppInfo.state should be(ApplicationState.UNKNOWN) + fakeWorkerInfo.coresFree should be(16) + fakeWorkerInfo.coresUsed should be(0) + + master.self.send(MasterChangeAcknowledged(fakeAppInfo.id)) + eventually(timeout(1 second), interval(10 milliseconds)) { + // Application state should be WAITING once the "MasterChangeAcknowledged" event has been processed.
+ fakeAppInfo.state should be(ApplicationState.WAITING) + } + + master.self.send( + WorkerSchedulerStateResponse(fakeWorkerInfo.id, fakeExecutors, Seq(fakeDriverInfo.id))) + + eventually(timeout(5 seconds), interval(100 milliseconds)) { + getState(master) should be(RecoveryState.ALIVE) + } + + // If driver's resource is also counted, free cores should 0 + fakeWorkerInfo.coresFree should be(0) + fakeWorkerInfo.coresUsed should be(16) + // State of application should be RUNNING + fakeAppInfo.state should be(ApplicationState.RUNNING) + } finally { + if (master != null) { + master.rpcEnv.shutdown() + master.rpcEnv.awaitTermination() + master = null + FakeRecoveryModeFactory.persistentData.clear() + } + } + } + test("master/worker web ui available") { implicit val formats = org.json4s.DefaultFormats val conf = new SparkConf() @@ -394,6 +518,9 @@ class MasterSuite extends SparkFunSuite // ========================================== private val _scheduleExecutorsOnWorkers = PrivateMethod[Array[Int]]('scheduleExecutorsOnWorkers) + private val _drivers = PrivateMethod[HashSet[DriverInfo]]('drivers) + private val _state = PrivateMethod[RecoveryState.Value]('state) + private val workerInfo = makeWorkerInfo(4096, 10) private val workerInfos = Array(workerInfo, workerInfo, workerInfo) @@ -412,12 +539,18 @@ class MasterSuite extends SparkFunSuite val desc = new ApplicationDescription( "test", maxCores, memoryPerExecutorMb, null, "", None, None, coresPerExecutor) val appId = System.currentTimeMillis.toString - new ApplicationInfo(0, appId, desc, new Date, null, Int.MaxValue) + val endpointRef = mock(classOf[RpcEndpointRef]) + val mockAddress = mock(classOf[RpcAddress]) + when(endpointRef.address).thenReturn(mockAddress) + new ApplicationInfo(0, appId, desc, new Date, endpointRef, Int.MaxValue) } private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = { val workerId = System.currentTimeMillis.toString - new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, "http://localhost:80") + val endpointRef = mock(classOf[RpcEndpointRef]) + val mockAddress = mock(classOf[RpcAddress]) + when(endpointRef.address).thenReturn(mockAddress) + new WorkerInfo(workerId, "host", 100, cores, memoryMb, endpointRef, "http://localhost:80") } private def scheduleExecutorsOnWorkers( @@ -499,4 +632,104 @@ class MasterSuite extends SparkFunSuite assert(receivedMasterAddress === RpcAddress("localhost2", 10000)) } } + + test("SPARK-19900: there should be a corresponding driver for the app after relaunching driver") { + val conf = new SparkConf().set("spark.worker.timeout", "1") + val master = makeMaster(conf) + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") + } + val worker1 = new MockWorker(master.self) + worker1.rpcEnv.setupEndpoint("worker", worker1) + val worker1Reg = RegisterWorker( + worker1.id, + "localhost", + 9998, + worker1.self, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost2", 10000)) + master.self.send(worker1Reg) + val driver = DeployTestUtils.createDriverDesc().copy(supervise = true) + master.self.askSync[SubmitDriverResponse](RequestSubmitDriver(driver)) + + eventually(timeout(10.seconds)) { + assert(worker1.apps.nonEmpty) + } + + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.workers(0).state == 
WorkerState.DEAD) + } + + val worker2 = new MockWorker(master.self) + worker2.rpcEnv.setupEndpoint("worker", worker2) + master.self.send(RegisterWorker( + worker2.id, + "localhost", + 9999, + worker2.self, + 10, + 1024, + "http://localhost:8081", + RpcAddress("localhost", 10001))) + eventually(timeout(10.seconds)) { + assert(worker2.apps.nonEmpty) + } + + master.self.send(worker1Reg) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + + val worker = masterState.workers.filter(w => w.id == worker1.id) + assert(worker.length == 1) + // make sure the `DriverStateChanged` arrives at Master. + assert(worker(0).drivers.isEmpty) + assert(worker1.apps.isEmpty) + assert(worker1.drivers.isEmpty) + assert(worker2.apps.size == 1) + assert(worker2.drivers.size == 1) + assert(masterState.activeDrivers.length == 1) + assert(masterState.activeApps.length == 1) + } + } + + private def getDrivers(master: Master): HashSet[DriverInfo] = { + master.invokePrivate(_drivers()) + } + + private def getState(master: Master): RecoveryState.Value = { + master.invokePrivate(_state()) + } +} + +private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) + extends StandaloneRecoveryModeFactory(conf, ser) { + import FakeRecoveryModeFactory.persistentData + + override def createPersistenceEngine(): PersistenceEngine = new PersistenceEngine { + + override def unpersist(name: String): Unit = { + persistentData.remove(name) + } + + override def persist(name: String, obj: Object): Unit = { + persistentData(name) = obj + } + + override def read[T: ClassTag](prefix: String): Seq[T] = { + persistentData.filter(_._1.startsWith(prefix)).map(_._2.asInstanceOf[T]).toSeq + } + } + + override def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { + new MonarchyLeaderAgent(master) + } +} + +private object FakeRecoveryModeFactory { + val persistentData = new HashMap[String, Object]() } diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala new file mode 100644 index 0000000000000..335f3449cb782 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { + private var delegationTokenManager: HadoopDelegationTokenManager = null + private var sparkConf: SparkConf = null + private var hadoopConf: Configuration = null + + override def beforeAll(): Unit = { + super.beforeAll() + + sparkConf = new SparkConf() + hadoopConf = new Configuration() + } + + test("Correctly load default credential providers") { + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("bogus") should be (None) + } + + test("disable hive credential provider") { + sparkConf.set("spark.security.credentials.hive.enabled", "false") + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) + } + + test("using deprecated configurations") { + sparkConf.set("spark.yarn.security.tokens.hadoopfs.enabled", "false") + sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false") + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + } + + test("verify no credentials are obtained") { + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + val creds = new Credentials() + + // Tokens cannot be obtained from HDFS, Hive, HBase in unit tests. 
+ delegationTokenManager.obtainDelegationTokens(hadoopConf, creds) + val tokens = creds.getAllTokens + tokens.size() should be (0) + } + + test("obtain tokens For HiveMetastore") { + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.kerberos.principal", "bob") + // thrift picks up on port 0 and bails out, without trying to talk to endpoint + hadoopConf.set("hive.metastore.uris", "http://localhost:0") + + val hiveCredentialProvider = new HiveDelegationTokenProvider() + val credentials = new Credentials() + hiveCredentialProvider.obtainDelegationTokens(hadoopConf, credentials) + + credentials.getAllTokens.size() should be (0) + } + + test("Obtain tokens For HBase") { + val hadoopConf = new Configuration() + hadoopConf.set("hbase.security.authentication", "kerberos") + + val hbaseTokenProvider = new HBaseDelegationTokenProvider() + val creds = new Credentials() + hbaseTokenProvider.obtainDelegationTokens(hadoopConf, creds) + + creds.getAllTokens.size should be (0) + } + + private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = { + Set(FileSystem.get(hadoopConf)) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 101a44edd8ee2..ce212a7513310 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.worker -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.{Command, ExecutorState} @@ -25,7 +25,7 @@ import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorState import org.apache.spark.deploy.master.DriverState import org.apache.spark.rpc.{RpcAddress, RpcEnv} -class WorkerSuite extends SparkFunSuite with Matchers { +class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { import org.apache.spark.deploy.DeployTestUtils._ @@ -34,6 +34,25 @@ class WorkerSuite extends SparkFunSuite with Matchers { } def conf(opts: (String, String)*): SparkConf = new SparkConf(loadDefaults = false).setAll(opts) + private var _worker: Worker = _ + + private def makeWorker(conf: SparkConf): Worker = { + assert(_worker === null, "Some Worker's RpcEnv is leaked in tests") + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, securityMgr) + _worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "Worker", "/tmp", conf, securityMgr) + _worker + } + + after { + if (_worker != null) { + _worker.rpcEnv.shutdown() + _worker.rpcEnv.awaitTermination() + _worker = null + } + } + test("test isUseLocalNodeSSLConfig") { Worker.isUseLocalNodeSSLConfig(cmd("-Dasdf=dfgh")) shouldBe false Worker.isUseLocalNodeSSLConfig(cmd("-Dspark.ssl.useNodeLocalConf=true")) shouldBe true @@ -65,9 +84,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedExecutors (small number of executors)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedExecutors", 2.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // 
initialize workers for (i <- 0 until 5) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -91,9 +108,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedExecutors (more executors)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedExecutors", 30.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 50) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -126,9 +141,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedDrivers (small number of drivers)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedDrivers", 2.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 5) { val driverId = s"driverId-$i" @@ -152,9 +165,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedDrivers (more drivers)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedDrivers", 30.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 50) { val driverId = s"driverId-$i" diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 271ab8b148831..98259300381eb 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -80,7 +80,8 @@ class NettyBlockTransferServiceSuite private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = { actualPort should be >= expectedPort // avoid testing equality in case of simultaneous tests - actualPort should be <= (expectedPort + 10) + // the default value for `spark.port.maxRetries` is 100 under test + actualPort should be <= (expectedPort + 100) } private def createService(port: Int): NettyBlockTransferService = { diff --git a/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala index da3256bd882e8..3c1208c2c375c 100644 --- a/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala @@ -23,21 +23,21 @@ class CountEvaluatorSuite extends SparkFunSuite { test("test count 0") { val evaluator = new CountEvaluator(10, 0.95) - assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity)) evaluator.merge(1, 0) - assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + 
assert(evaluator.currentResult() === new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity)) } test("test count >= 1") { val evaluator = new CountEvaluator(10, 0.95) evaluator.merge(1, 1) - assert(new BoundedDouble(10.0, 0.95, 1.0, 36.0) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(10.0, 0.95, 5.0, 16.0)) evaluator.merge(1, 3) - assert(new BoundedDouble(20.0, 0.95, 7.0, 41.0) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(20.0, 0.95, 13.0, 28.0)) evaluator.merge(1, 8) - assert(new BoundedDouble(40.0, 0.95, 24.0, 61.0) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(40.0, 0.95, 30.0, 51.0)) (4 to 10).foreach(_ => evaluator.merge(1, 10)) - assert(new BoundedDouble(82.0, 1.0, 82.0, 82.0) == evaluator.currentResult()) + assert(evaluator.currentResult() === new BoundedDouble(82.0, 1.0, 82.0, 82.0)) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index 2802cd975292c..9e204f5cc33fe 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.rdd +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} + import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -168,6 +172,10 @@ class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { // Collecting the RDD should now fail with an informative exception val blockId = RDDBlockId(rdd.id, numPartitions - 1) bmm.removeBlock(blockId) + // Wait until the block has been removed successfully. 
+ eventually(timeout(1 seconds), interval(100 milliseconds)) { + assert(bmm.getBlockStatus(blockId).isEmpty) + } try { rdd.collect() fail("Collect should have failed if local checkpoint block is removed...") diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 2b18ebee79a2b..571c6bbb4585d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M sc = new SparkContext(conf) val scheduler = mock[TaskSchedulerImpl] when(scheduler.sc).thenReturn(sc) - when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker) + when(scheduler.mapOutputTracker).thenReturn( + SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]) scheduler } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 67145e7445061..ddd3281106745 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -396,6 +396,73 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } + test("All shuffle files on the slave should be cleaned up when slave lost") { + // reset the test context with the right shuffle service config + afterEach() + val conf = new SparkConf() + conf.set("spark.shuffle.service.enabled", "true") + conf.set("spark.files.fetchFailure.unRegisterOutputOnHost", "true") + init(conf) + runEvent(ExecutorAdded("exec-hostA1", "hostA")) + runEvent(ExecutorAdded("exec-hostA2", "hostA")) + runEvent(ExecutorAdded("exec-hostB", "hostB")) + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(3)) + val firstShuffleId = firstShuffleDep.shuffleId + val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(3)) + val secondShuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + // map stage1 completes successfully, with one task on each executor + complete(taskSets(0), Seq( + (Success, + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + (Success, + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + (Success, makeMapStatus("hostB", 1)) + )) + // map stage2 completes successfully, with one task on each executor + complete(taskSets(1), Seq( + (Success, + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + (Success, + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + (Success, makeMapStatus("hostB", 1)) + )) + // make sure our test setup is correct + val initialMapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses + // val initialMapStatus1 = mapOutputTracker.mapStatuses.get(0).get + assert(initialMapStatus1.count(_ != null) === 3) + assert(initialMapStatus1.map{_.location.executorId}.toSet === + Set("exec-hostA1", "exec-hostA2", "exec-hostB")) + + val initialMapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses + // val initialMapStatus1 = 
mapOutputTracker.mapStatuses.get(0).get + assert(initialMapStatus2.count(_ != null) === 3) + assert(initialMapStatus2.map{_.location.executorId}.toSet === + Set("exec-hostA1", "exec-hostA2", "exec-hostB")) + + // reduce stage fails with a fetch failure from one host + complete(taskSets(2), Seq( + (FetchFailed(BlockManagerId("exec-hostA2", "hostA", 12345), firstShuffleId, 0, 0, "ignored"), + null) + )) + + // Here is the main assertion -- make sure that we de-register + // the map outputs for both map stage from both executors on hostA + + val mapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses + assert(mapStatus1.count(_ != null) === 1) + assert(mapStatus1(2).location.executorId === "exec-hostB") + assert(mapStatus1(2).location.host === "hostB") + + val mapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses + assert(mapStatus2.count(_ != null) === 1) + assert(mapStatus2(2).location.executorId === "exec-hostB") + assert(mapStatus2(2).location.host === "hostB") + } + test("zero split job") { var numResults = 0 var failureReason: Option[Exception] = None diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 2355d40d1e6fe..607234b4068d0 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -93,16 +93,13 @@ INDEX .lintr gen-java.* .*avpr -org.apache.spark.sql.sources.DataSourceRegister -org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet spark-deps-.* .*csv .*tsv -org.apache.spark.scheduler.ExternalClusterManager .*\.sql .Rbuildignore -org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +META-INF/* spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ab1de3d3dd8ad..9127413ab6c23 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -47,9 +47,9 @@ commons-net-2.2.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar -curator-client-2.6.0.jar -curator-framework-2.6.0.jar -curator-recipes-2.6.0.jar +curator-client-2.7.1.jar +curator-framework-2.7.1.jar +curator-recipes-2.7.1.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8745e76d127ae..ec130c1db8f5f 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -382,8 +382,9 @@ See the [configuration page](configuration.html) for information on Spark config (none) Set the Mesos labels to add to each task. Labels are free-form key-value pairs. - Key-value pairs should be separated by a colon, and commas used to list more than one. - Ex. key:value,key2:value2. + Key-value pairs should be separated by a colon, and commas used to + list more than one. If your label includes a colon or comma, you + can escape it with a backslash. Ex. key:value,key2:a\:b. @@ -468,6 +469,15 @@ See the [configuration page](configuration.html) for information on Spark config If unset it will point to Spark's internal web UI. + + spark.mesos.driver.labels + (none) + + Mesos labels to add to the driver. See spark.mesos.task.labels + for formatting information. 
+ + + spark.mesos.driverEnv.[EnvironmentVariableName] (none) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 2d56123028f2b..e4a74556d4f26 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -419,7 +419,7 @@ To use a custom metrics.properties for the application master and executors, upd - spark.yarn.security.credentials.${service}.enabled + spark.security.credentials.${service}.enabled true Controls whether to obtain credentials for services when security is enabled. @@ -482,11 +482,11 @@ token for the cluster's default Hadoop filesystem, and potentially for HBase and An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares the application is secure (i.e. `hbase-site.xml` sets `hbase.security.authentication` to `kerberos`), -and `spark.yarn.security.credentials.hbase.enabled` is not set to `false`. +and `spark.security.credentials.hbase.enabled` is not set to `false`. Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration includes a URI of the metadata store in `"hive.metastore.uris`, and -`spark.yarn.security.credentials.hive.enabled` is not set to `false`. +`spark.security.credentials.hive.enabled` is not set to `false`. If an application needs to interact with other secure Hadoop filesystems, then the tokens needed to access these clusters must be explicitly requested at @@ -500,7 +500,7 @@ Spark supports integrating with other security-aware services through Java Servi `java.util.ServiceLoader`). To do that, implementations of `org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` should be available to Spark by listing their names in the corresponding file in the jar's `META-INF/services` directory. These plug-ins can be disabled by setting -`spark.yarn.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of +`spark.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of credential provider. ## Configuring the External Shuffle Service @@ -564,8 +564,8 @@ the Spark configuration must be set to disable token collection for the services The Spark configuration must include the lines: ``` -spark.yarn.security.credentials.hive.enabled false -spark.yarn.security.credentials.hbase.enabled false +spark.security.credentials.hive.enabled false +spark.security.credentials.hbase.enabled false ``` The configuration option `spark.yarn.access.hadoopFileSystems` must be unset. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 314ff6ef80d29..8e722ae6adca6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -998,7 +998,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set the `wholeFile` option to `true`. +For a regular multi-line JSON file, set the `multiLine` option to `true`. {% include_example json_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} @@ -1012,7 +1012,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). 
-For a regular multi-line JSON file, set the `wholeFile` option to `true`. +For a regular multi-line JSON file, set the `multiLine` option to `true`. {% include_example json_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} @@ -1025,7 +1025,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set the `wholeFile` parameter to `True`. +For a regular multi-line JSON file, set the `multiLine` parameter to `True`. {% include_example json_dataset python/sql/datasource.py %} @@ -1039,7 +1039,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set a named parameter `wholeFile` to `TRUE`. +For a regular multi-line JSON file, set a named parameter `multiLine` to `TRUE`. {% include_example json_dataset r/RSparkSQLExample.R %} diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 6a25c9939c264..9b9177d44145f 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1056,7 +1056,7 @@ Some of them are as follows. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). -- `count()` - Cannot return a single count from a streaming Dataset. Instead, use `ds.groupBy.count()` which returns a streaming Dataset containing a running count. +- `count()` - Cannot return a single count from a streaming Dataset. Instead, use `ds.groupBy().count()` which returns a streaming Dataset containing a running count. - `foreach()` - Instead use `ds.writeStream.foreach(...)` (see next section). 
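As a side note on the `ds.groupBy().count()` guidance above, the following is a minimal PySpark sketch of a running count over a streaming DataFrame; the socket source, host, and port are assumptions chosen purely for illustration and are not part of this change.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("running-count-sketch").getOrCreate()

# A streaming DataFrame; the socket source and port are hypothetical.
lines = (spark.readStream
         .format("socket")
         .option("host", "localhost")
         .option("port", 9999)
         .load())

# count() is not supported on a streaming DataFrame; groupBy().count()
# produces a streaming DataFrame with a running count instead.
running_count = lines.groupBy().count()

# Running aggregations are emitted with the complete (or update) output mode.
query = (running_count.writeStream
         .outputMode("complete")
         .format("console")
         .start())

query.awaitTermination()
```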
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 4ca062c0b5adf..b6909b3386b71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} @@ -339,25 +340,42 @@ object Word2VecModel extends MLReadable[Word2VecModel] { val wordVectors = instance.wordVectors.getVectors val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) } val dataPath = new Path(path, "data").toString + val bufferSizeInBytes = Utils.byteStringAsBytes( + sc.conf.get("spark.kryoserializer.buffer.max", "64m")) + val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions( + bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize) sparkSession.createDataFrame(dataSeq) - .repartition(calculateNumberOfPartitions) + .repartition(numPartitions) .write .parquet(dataPath) } + } - def calculateNumberOfPartitions(): Int = { - val floatSize = 4 + private[feature] + object Word2VecModelWriter { + /** + * Calculate the number of partitions to use in saving the model. + * [SPARK-11994] - We want to partition the model in partitions smaller than + * spark.kryoserializer.buffer.max + * @param bufferSizeInBytes Set to spark.kryoserializer.buffer.max + * @param numWords Vocab size + * @param vectorSize Vector length for each word + */ + def calculateNumberOfPartitions( + bufferSizeInBytes: Long, + numWords: Int, + vectorSize: Int): Int = { + val floatSize = 4L // Use Long to help avoid overflow val averageWordSize = 15 - // [SPARK-11994] - We want to partition the model in partitions smaller than - // spark.kryoserializer.buffer.max - val bufferSizeInBytes = Utils.byteStringAsBytes( - sc.conf.get("spark.kryoserializer.buffer.max", "64m")) // Calculate the approximate size of the model. // Assuming an average word size of 15 bytes, the formula is: // (floatSize * vectorSize + 15) * numWords - val numWords = instance.wordVectors.wordIndex.size - val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords - ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt + val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords + val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1 + require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " + + s"partitions to save this model, which is too large. 
Try increasing " + + s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.") + numPartitions.toInt } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a6a1c2b4f32bd..6183606a7b2ac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row +import org.apache.spark.util.Utils class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -188,6 +189,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) } + test("Word2Vec read/write numPartitions calculation") { + val smallModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions( + Utils.byteStringAsBytes("64m"), numWords = 10, vectorSize = 5) + assert(smallModelNumPartitions === 1) + val largeModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions( + Utils.byteStringAsBytes("64m"), numWords = 1000000, vectorSize = 5000) + assert(largeModelNumPartitions > 1) + } + test("Word2Vec read/write") { val t = new Word2Vec() .setInputCol("myInputCol") diff --git a/pom.xml b/pom.xml index 6835ea14cd42b..5f524079495c0 100644 --- a/pom.xml +++ b/pom.xml @@ -2532,6 +2532,7 @@ hadoop-2.7 2.7.3 + 2.7.1 diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 5cf719bd65ae4..aef71f9ca7001 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -174,12 +174,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - wholeFile=None): + multiLine=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. `JSON Lines `_ (newline-delimited JSON) is supported by default. - For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. + For JSON (one record per file), set the ``multiLine`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -230,7 +230,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. - :param wholeFile: parse one record, which may span multiple lines, per file. If None is + :param multiLine: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. 
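To make the renamed option concrete, here is a short, hedged PySpark sketch; the file paths are hypothetical and only illustrate the `multiLine` parameter (the CSV reader in this change accepts the same parameter name).

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("multiline-json-sketch").getOrCreate()

# Default behaviour: JSON Lines input, one self-contained JSON object per line.
people = spark.read.json("path/to/people.json")

# multiLine=True: each file holds a single JSON record that may span several
# lines, e.g. a pretty-printed JSON array or object.
people_array = spark.read.json("path/to/people_array.json", multiLine=True)

# The CSV reader uses the same flag for records whose fields contain newlines.
ages = spark.read.csv("path/to/ages_newlines.csv", multiLine=True)
```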
>>> df1 = spark.read.json('python/test_support/sql/people.json') @@ -248,7 +248,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, wholeFile=wholeFile) + timestampFormat=timestampFormat, multiLine=multiLine) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -322,7 +322,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, wholeFile=None): + columnNameOfCorruptRecord=None, multiLine=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -396,7 +396,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. - :param wholeFile: parse records, which may span multiple lines. If None is + :param multiLine: parse records, which may span multiple lines. If None is set, it uses the default value, ``false``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') @@ -411,7 +411,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, - columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 76e8c4f47d8ad..58aa2468e006d 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -401,12 +401,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - wholeFile=None): + multiLine=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. `JSON Lines `_ (newline-delimited JSON) is supported by default. - For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. + For JSON (one record per file), set the ``multiLine`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -458,7 +458,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. - :param wholeFile: parse one record, which may span multiple lines, per file. 
If None is + :param multiLine: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) @@ -473,7 +473,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, wholeFile=wholeFile) + timestampFormat=timestampFormat, multiLine=multiLine) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -532,7 +532,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, wholeFile=None): + columnNameOfCorruptRecord=None, multiLine=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -607,7 +607,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. - :param wholeFile: parse one record, which may span multiple lines. If None is + :param multiLine: parse one record, which may span multiple lines. If None is set, it uses the default value, ``false``. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) @@ -624,7 +624,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, - columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 845e1c7619cc4..31f932a363225 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -457,15 +457,15 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) - def test_wholefile_json(self): + def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", - wholeFile=True) + multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) - def test_wholefile_csv(self): + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( - "python/test_support/sql/ages_newlines.csv", wholeFile=True) + "python/test_support/sql/ages_newlines.csv", multiLine=True) expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index bb13de563cdd4..e1c5f007268a7 100644 --- a/python/pyspark/tests.py +++ 
b/python/pyspark/tests.py @@ -1,3 +1,4 @@ +# coding=utf-8 # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -1859,6 +1860,31 @@ def test_with_different_versions_of_python(self): finally: self.sc.pythonVer = version + def test_exception_blocking(self): + """ + SPARK-21045 + Test whether program is blocked when occur exception in worker sending + exception to PythonRDD + + """ + import threading + + def run(): + try: + + def f(): + raise Exception("中") + + self.sc.parallelize([1]).map(lambda x: f()).count() + except Exception: + pass + + t = threading.Thread(target=run) + t.daemon = True + t.start() + t.join(10) + self.assertFalse(t.isAlive(), 'Spark executor is blocked.') + class SparkSubmitTests(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index baaa3fe074e9a..11c6555b1fdc9 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -36,6 +36,9 @@ pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() +if sys.version >= '3': + unicode = str + def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) @@ -177,8 +180,11 @@ def process(): process() except Exception: try: + exc_info = traceback.format_exc() + if isinstance(exc_info, unicode): + exc_info = exc_info.encode('utf-8') write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc().encode("utf-8"), outfile) + write_with_length(exc_info, outfile) except IOError: # JVM close the socket pass diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 19e253394f1b2..56d697f359614 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -56,4 +56,11 @@ package object config { .stringConf .createOptional + private [spark] val DRIVER_LABELS = + ConfigBuilder("spark.mesos.driver.labels") + .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value" + + "pairs should be separated by a colon, and commas used to list more than one." + + "Ex. key:value,key2:value2") + .stringConf + .createOptional } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1bc6f71860c3f..577f9a876b381 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -30,11 +30,13 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} +import org.apache.spark.deploy.mesos.config import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils + /** * Tracks the current state of a Mesos Task that runs a Spark driver. 
* @param driverDescription Submitted driver description from @@ -525,15 +527,17 @@ private[spark] class MesosClusterScheduler( offer.remainingResources = finalResources.asJava val appName = desc.conf.get("spark.app.name") - val taskInfo = TaskInfo.newBuilder() + + TaskInfo.newBuilder() .setTaskId(taskId) .setName(s"Driver for ${appName}") .setSlaveId(offer.offer.getSlaveId) .setCommand(buildDriverCommand(desc)) .addAllResources(cpuResourcesToUse.asJava) .addAllResources(memResourcesToUse.asJava) - taskInfo.setContainer(MesosSchedulerBackendUtil.containerInfo(desc.conf)) - taskInfo.build + .setLabels(MesosProtoUtils.mesosLabels(desc.conf.get(config.DRIVER_LABELS).getOrElse(""))) + .setContainer(MesosSchedulerBackendUtil.containerInfo(desc.conf)) + .build } /** diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index ac7aec7b0a034..871685c6cccc0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -419,16 +419,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) .setName(s"${sc.appName} $taskId") - - taskBuilder.addAllResources(resourcesToUse.asJava) - taskBuilder.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) - - val labelsBuilder = taskBuilder.getLabelsBuilder - val labels = buildMesosLabels().asJava - - labelsBuilder.addAllLabels(labels) - - taskBuilder.setLabels(labelsBuilder) + .setLabels(MesosProtoUtils.mesosLabels(taskLabels)) + .addAllResources(resourcesToUse.asJava) + .setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava @@ -444,21 +437,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( tasks.toMap } - private def buildMesosLabels(): List[Label] = { - taskLabels.split(",").flatMap(label => - label.split(":") match { - case Array(key, value) => - Some(Label.newBuilder() - .setKey(key) - .setValue(value) - .build()) - case _ => - logWarning(s"Unable to parse $label into a key:value label for the task.") - None - } - ).toList - } - /** Extracts task needed resources from a list of available resources. */ private def partitionTaskResources( resources: JList[Resource], diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala new file mode 100644 index 0000000000000..fea01c7068c9a --- /dev/null +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import scala.collection.JavaConverters._ + +import org.apache.mesos.Protos + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging + +object MesosProtoUtils extends Logging { + + /** Parses a label string of the format specified in spark.mesos.task.labels. */ + def mesosLabels(labelsStr: String): Protos.Labels.Builder = { + val labels: Seq[Protos.Label] = if (labelsStr == "") { + Seq() + } else { + labelsStr.split("""(?<!\\),""").toSeq.map { labelStr => + val parts = labelStr.split("""(?<!\\):""") + if (parts.length != 2) { + throw new SparkException(s"Malformed label: ${labelStr}") + } + + val cleanedParts = parts + .map(part => part.replaceAll("""\\,""", ",")) + .map(part => part.replaceAll("""\\:""", ":")) + + Protos.Label.newBuilder() + .setKey(cleanedParts(0)) + .setValue(cleanedParts(1)) + .build() + } + } + + Protos.Labels.newBuilder().addAllLabels(labels.asJava) + } +} diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 32967b04cd346..0bb47906347d5 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -248,6 +248,33 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(networkInfos.get(0).getName == "test-network-name") } + + test("supports spark.mesos.driver.labels") { + setScheduler() + + val mem = 1000 + val cpu = 1 + + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", mem, cpu, true, + command, + Map("spark.mesos.executor.home" -> "test", + "spark.app.name" -> "test", + "spark.mesos.driver.labels" -> "key:value"), + "s1", + new Date())) + + assert(response.success) + + val offer = Utils.createOffer("o1", "s1", mem, cpu) + scheduler.resourceOffers(driver, List(offer).asJava) + + val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") + val labels = launchedTasks.head.getLabels + assert(labels.getLabelsCount == 1) + assert(labels.getLabels(0).getKey == "key") + assert(labels.getLabels(0).getValue == "value") + } + test("can kill supervised drivers") { val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 0418bfbaa5ed8..7cca5fedb31eb 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -532,29 +532,6 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getLabels.equals(taskLabels)) } - test("mesos ignored invalid labels and sets configurable labels on tasks") { - val
taskLabelsString = "mesos:test,label:test,incorrect:label:here" - setBackend(Map( - "spark.mesos.task.labels" -> taskLabelsString - )) - - // Build up the labels - val taskLabels = Protos.Labels.newBuilder() - .addLabels(Protos.Label.newBuilder() - .setKey("mesos").setValue("test").build()) - .addLabels(Protos.Label.newBuilder() - .setKey("label").setValue("test").build()) - .build() - - val offers = List(Resources(backend.executorMemory(sc), 1)) - offerResources(offers) - val launchedTasks = verifyTaskLaunched(driver, "o1") - - val labels = launchedTasks.head.getLabels - - assert(launchedTasks.head.getLabels.equals(taskLabels)) - } - test("mesos supports spark.mesos.network.name") { setBackend(Map( "spark.mesos.network.name" -> "test-network-name" diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala new file mode 100644 index 0000000000000..36a4c1ab1ad25 --- /dev/null +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.spark.SparkFunSuite + +class MesosProtoUtilsSuite extends SparkFunSuite { + test("mesosLabels") { + val labels = MesosProtoUtils.mesosLabels("key:value") + assert(labels.getLabelsCount == 1) + val label = labels.getLabels(0) + assert(label.getKey == "key") + assert(label.getValue == "value") + + val labels2 = MesosProtoUtils.mesosLabels("key:value\\:value") + assert(labels2.getLabelsCount == 1) + val label2 = labels2.getLabels(0) + assert(label2.getKey == "key") + assert(label2.getValue == "value:value") + + val labels3 = MesosProtoUtils.mesosLabels("key:value,key2:value2") + assert(labels3.getLabelsCount == 2) + assert(labels3.getLabels(0).getKey == "key") + assert(labels3.getLabels(0).getValue == "value") + assert(labels3.getLabels(1).getKey == "key2") + assert(labels3.getLabels(1).getValue == "value2") + + val labels4 = MesosProtoUtils.mesosLabels("key:value\\,value") + assert(labels4.getLabelsCount == 1) + assert(labels4.getLabels(0).getKey == "key") + assert(labels4.getLabels(0).getValue == "value,value") + } +} diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 71d4ad681e169..43a7ce95bd3de 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -167,29 +167,27 @@ ${jersey-1.version} - + ${hive.group} hive-exec - test + provided ${hive.group} hive-metastore - test + provided org.apache.thrift libthrift - test + provided org.apache.thrift libfb303 - test + provided diff --git a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider deleted file mode 100644 index f5a807ecac9d7..0000000000000 --- a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +++ /dev/null @@ -1,3 +0,0 @@ -org.apache.spark.deploy.yarn.security.HadoopFSCredentialProvider -org.apache.spark.deploy.yarn.security.HBaseCredentialProvider -org.apache.spark.deploy.yarn.security.HiveCredentialProvider diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 6da2c0b5f330a..4f71a1606312d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -38,7 +38,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, ConfigurableCredentialManager} +import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc._ @@ -247,8 +247,12 @@ private[spark] class ApplicationMaster( if (sparkConf.contains(CREDENTIALS_FILE_PATH.key)) { // If a principal and keytab have been set, use that to create new credentials for executors // periodically - credentialRenewer = - new ConfigurableCredentialManager(sparkConf, yarnConf).credentialRenewer() + val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + yarnConf, + 
YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, yarnConf)) + + val credentialRenewer = new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) credentialRenewer.scheduleLoginFromKeytab() } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1fb7edf2a6e30..e5131e636dc04 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -49,7 +49,7 @@ import org.apache.hadoop.yarn.util.Records import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.ConfigurableCredentialManager +import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} @@ -121,7 +121,10 @@ private[spark] class Client( private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) } .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory()) - private val credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) + private val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, hadoopConf)) def reportLauncherState(state: SparkAppHandle.State): Unit = { launcherBackend.setState(state) @@ -368,7 +371,7 @@ private[spark] class Client( val fs = destDir.getFileSystem(hadoopConf) // Merge credentials obtained from registered providers - val nearestTimeOfNextRenewal = credentialManager.obtainCredentials(hadoopConf, credentials) + val nearestTimeOfNextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, credentials) if (credentials != null) { // Add credentials to current user's UGI, so that following operations don't need to use the diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 0fc994d629ccb..4522071bd92e2 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -24,8 +24,9 @@ import java.util.regex.Pattern import scala.collection.mutable.{HashMap, ListBuffer} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.{JobConf, Master} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.ApplicationConstants @@ -35,11 +36,14 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.yarn.security.{ConfigurableCredentialManager, CredentialUpdater} +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.deploy.yarn.security.CredentialUpdater +import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import 
org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.util.Utils + /** * Contains util methods to interact with Hadoop from spark. */ @@ -87,8 +91,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } private[spark] override def startCredentialUpdater(sparkConf: SparkConf): Unit = { - credentialUpdater = - new ConfigurableCredentialManager(sparkConf, newConfiguration(sparkConf)).credentialUpdater() + val hadoopConf = newConfiguration(sparkConf) + val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, hadoopConf)) + credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) credentialUpdater.start() } @@ -103,6 +111,21 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) ConverterUtils.toContainerId(containerIdString) } + + /** The filesystems for which YARN should fetch delegation tokens. */ + private[spark] def hadoopFSsToAccess( + sparkConf: SparkConf, + hadoopConf: Configuration): Set[FileSystem] = { + val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS) + .map(new Path(_).getFileSystem(hadoopConf)) + .toSet + + val stagingFS = sparkConf.get(STAGING_DIR) + .map(new Path(_).getFileSystem(hadoopConf)) + .getOrElse(FileSystem.get(hadoopConf)) + + filesystemsToAccess + stagingFS + } } object YarnSparkHadoopUtil { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index 7e76f402db249..68a2e9e70a78b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -54,7 +54,7 @@ import org.apache.spark.util.ThreadUtils private[yarn] class AMCredentialRenewer( sparkConf: SparkConf, hadoopConf: Configuration, - credentialManager: ConfigurableCredentialManager) extends Logging { + credentialManager: YARNHadoopDelegationTokenManager) extends Logging { private var lastCredentialsFileSuffix = 0 @@ -174,7 +174,9 @@ private[yarn] class AMCredentialRenewer( keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { // Get a copy of the credentials override def run(): Void = { - nearestNextRenewalTime = credentialManager.obtainCredentials(freshHadoopConf, tempCreds) + nearestNextRenewalTime = credentialManager.obtainDelegationTokens( + freshHadoopConf, + tempCreds) null } }) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala deleted file mode 100644 index 4f4be52a0d691..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import java.util.ServiceLoader - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.Credentials - -import org.apache.spark.SparkConf -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils - -/** - * A ConfigurableCredentialManager to manage all the registered credential providers and offer - * APIs for other modules to obtain credentials as well as renewal time. By default - * [[HadoopFSCredentialProvider]], [[HiveCredentialProvider]] and [[HBaseCredentialProvider]] will - * be loaded in if not explicitly disabled, any plugged-in credential provider wants to be - * managed by ConfigurableCredentialManager needs to implement [[ServiceCredentialProvider]] - * interface and put into resources/META-INF/services to be loaded by ServiceLoader. - * - * Also each credential provider is controlled by - * spark.yarn.security.credentials.{service}.enabled, it will not be loaded in if set to false. - * For example, Hive's credential provider [[HiveCredentialProvider]] can be enabled/disabled by - * the configuration spark.yarn.security.credentials.hive.enabled. - */ -private[yarn] final class ConfigurableCredentialManager( - sparkConf: SparkConf, hadoopConf: Configuration) extends Logging { - private val deprecatedProviderEnabledConfig = "spark.yarn.security.tokens.%s.enabled" - private val providerEnabledConfig = "spark.yarn.security.credentials.%s.enabled" - - // Maintain all the registered credential providers - private val credentialProviders = { - val providers = ServiceLoader.load(classOf[ServiceCredentialProvider], - Utils.getContextOrSparkClassLoader).asScala - - // Filter out credentials in which spark.yarn.security.credentials.{service}.enabled is false. - providers.filter { p => - sparkConf.getOption(providerEnabledConfig.format(p.serviceName)) - .orElse { - sparkConf.getOption(deprecatedProviderEnabledConfig.format(p.serviceName)).map { c => - logWarning(s"${deprecatedProviderEnabledConfig.format(p.serviceName)} is deprecated, " + - s"using ${providerEnabledConfig.format(p.serviceName)} instead") - c - } - }.map(_.toBoolean).getOrElse(true) - }.map { p => (p.serviceName, p) }.toMap - } - - /** - * Get credential provider for the specified service. - */ - def getServiceCredentialProvider(service: String): Option[ServiceCredentialProvider] = { - credentialProviders.get(service) - } - - /** - * Obtain credentials from all the registered providers. - * @return nearest time of next renewal, Long.MaxValue if all the credentials aren't renewable, - * otherwise the nearest renewal time of any credentials will be returned. - */ - def obtainCredentials(hadoopConf: Configuration, creds: Credentials): Long = { - credentialProviders.values.flatMap { provider => - if (provider.credentialsRequired(hadoopConf)) { - provider.obtainCredentials(hadoopConf, sparkConf, creds) - } else { - logDebug(s"Service ${provider.serviceName} does not require a token." 
+ - s" Check your configuration to see if security is disabled or not.") - None - } - }.foldLeft(Long.MaxValue)(math.min) - } - - /** - * Create an [[AMCredentialRenewer]] instance, caller should be responsible to stop this - * instance when it is not used. AM will use it to renew credentials periodically. - */ - def credentialRenewer(): AMCredentialRenewer = { - new AMCredentialRenewer(sparkConf, hadoopConf, this) - } - - /** - * Create an [[CredentialUpdater]] instance, caller should be resposible to stop this intance - * when it is not used. Executors and driver (client mode) will use it to update credentials. - * periodically. - */ - def credentialUpdater(): CredentialUpdater = { - new CredentialUpdater(sparkConf, hadoopConf, this) - } -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala index 41b7b5d60b038..fe173dffc22a8 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class CredentialUpdater( sparkConf: SparkConf, hadoopConf: Configuration, - credentialManager: ConfigurableCredentialManager) extends Logging { + credentialManager: YARNHadoopDelegationTokenManager) extends Logging { @volatile private var lastCredentialsFileSuffix = 0 diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala deleted file mode 100644 index f65c886db944e..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.deploy.yarn.security - -import scala.collection.JavaConverters._ -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapred.Master -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ - -private[security] class HadoopFSCredentialProvider - extends ServiceCredentialProvider with Logging { - // Token renewal interval, this value will be set in the first call, - // if None means no token renewer specified or no token can be renewed, - // so cannot get token renewal interval. - private var tokenRenewalInterval: Option[Long] = null - - override val serviceName: String = "hadoopfs" - - override def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] = { - // NameNode to access, used to get tokens from different FileSystems - val tmpCreds = new Credentials() - val tokenRenewer = getTokenRenewer(hadoopConf) - hadoopFSsToAccess(hadoopConf, sparkConf).foreach { dst => - val dstFs = dst.getFileSystem(hadoopConf) - logInfo("getting token for: " + dst) - dstFs.addDelegationTokens(tokenRenewer, tmpCreds) - } - - // Get the token renewal interval if it is not set. It will only be called once. - if (tokenRenewalInterval == null) { - tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf) - } - - // Get the time of next renewal. - val nextRenewalDate = tokenRenewalInterval.flatMap { interval => - val nextRenewalDates = tmpCreds.getAllTokens.asScala - .filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]) - .map { t => - val identifier = t.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] - identifier.getIssueDate + interval - } - if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min) - } - - creds.addAll(tmpCreds) - nextRenewalDate - } - - private def getTokenRenewalInterval( - hadoopConf: Configuration, sparkConf: SparkConf): Option[Long] = { - // We cannot use the tokens generated with renewer yarn. Trying to renew - // those will fail with an access control issue. So create new tokens with the logged in - // user as renewer. 
- sparkConf.get(PRINCIPAL).flatMap { renewer => - val creds = new Credentials() - hadoopFSsToAccess(hadoopConf, sparkConf).foreach { dst => - val dstFs = dst.getFileSystem(hadoopConf) - dstFs.addDelegationTokens(renewer, creds) - } - - val renewIntervals = creds.getAllTokens.asScala.filter { - _.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier] - }.flatMap { token => - Try { - val newExpiration = token.renew(hadoopConf) - val identifier = token.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] - val interval = newExpiration - identifier.getIssueDate - logInfo(s"Renewal interval is $interval for token ${token.getKind.toString}") - interval - }.toOption - } - if (renewIntervals.isEmpty) None else Some(renewIntervals.min) - } - } - - private def getTokenRenewer(conf: Configuration): String = { - val delegTokenRenewer = Master.getMasterPrincipal(conf) - logDebug("delegation token renewer is: " + delegTokenRenewer) - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - - delegTokenRenewer - } - - private def hadoopFSsToAccess(hadoopConf: Configuration, sparkConf: SparkConf): Set[Path] = { - sparkConf.get(FILESYSTEMS_TO_ACCESS).map(new Path(_)).toSet + - sparkConf.get(STAGING_DIR).map(new Path(_)) - .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory) - } -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala index 4e3fcce8dbb1d..cc24ac4d9bcf6 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala @@ -35,7 +35,7 @@ trait ServiceCredentialProvider { def serviceName: String /** - * To decide whether credential is required for this service. By default it based on whether + * Returns true if credentials are required by this service. By default, it is based on whether * Hadoop security is enabled. */ def credentialsRequired(hadoopConf: Configuration): Boolean = { @@ -44,6 +44,7 @@ trait ServiceCredentialProvider { /** * Obtain credentials for this service and get the time of the next renewal. + * * @param hadoopConf Configuration of current Hadoop Compatible system. * @param sparkConf Spark configuration. * @param creds Credentials to add tokens and security keys to. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala new file mode 100644 index 0000000000000..bbd17c8fc1272 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn.security + +import java.util.ServiceLoader + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * This class loads delegation token providers registered under the YARN-specific + * [[ServiceCredentialProvider]] interface, as well as the builtin providers defined + * in [[HadoopDelegationTokenManager]]. + */ +private[yarn] class YARNHadoopDelegationTokenManager( + sparkConf: SparkConf, + hadoopConf: Configuration, + fileSystems: Set[FileSystem]) extends Logging { + + private val delegationTokenManager = + new HadoopDelegationTokenManager(sparkConf, hadoopConf, fileSystems) + + // public for testing + val credentialProviders = getCredentialProviders + + /** + * Writes delegation tokens to creds. Delegation tokens are fetched from all registered + * providers. + * + * @return Time after which the fetched delegation tokens should be renewed. + */ + def obtainDelegationTokens(hadoopConf: Configuration, creds: Credentials): Long = { + val superInterval = delegationTokenManager.obtainDelegationTokens(hadoopConf, creds) + + credentialProviders.values.flatMap { provider => + if (provider.credentialsRequired(hadoopConf)) { + provider.obtainCredentials(hadoopConf, sparkConf, creds) + } else { + logDebug(s"Service ${provider.serviceName} does not require a token." + + s" Check your configuration to see if security is disabled or not.") + None + } + }.foldLeft(superInterval)(math.min) + } + + private def getCredentialProviders: Map[String, ServiceCredentialProvider] = { + val providers = loadCredentialProviders + + providers. 
+ filter { p => delegationTokenManager.isServiceEnabled(p.serviceName) } + .map { p => (p.serviceName, p) } + .toMap + } + + private def loadCredentialProviders: List[ServiceCredentialProvider] = { + ServiceLoader.load(classOf[ServiceCredentialProvider], Utils.getContextOrSparkClassLoader) + .asScala + .toList + } +} diff --git a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider index d0ef5efa36e86..f31c232693133 100644 --- a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +++ b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider @@ -1 +1 @@ -org.apache.spark.deploy.yarn.security.TestCredentialProvider +org.apache.spark.deploy.yarn.security.YARNTestCredentialProvider diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala deleted file mode 100644 index b0067aa4517c7..0000000000000 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.io.Text -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.token.Token -import org.scalatest.{BeforeAndAfter, Matchers} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.config._ - -class ConfigurableCredentialManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { - private var credentialManager: ConfigurableCredentialManager = null - private var sparkConf: SparkConf = null - private var hadoopConf: Configuration = null - - override def beforeAll(): Unit = { - super.beforeAll() - - sparkConf = new SparkConf() - hadoopConf = new Configuration() - System.setProperty("SPARK_YARN_MODE", "true") - } - - override def afterAll(): Unit = { - System.clearProperty("SPARK_YARN_MODE") - - super.afterAll() - } - - test("Correctly load default credential providers") { - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - - credentialManager.getServiceCredentialProvider("hadoopfs") should not be (None) - credentialManager.getServiceCredentialProvider("hbase") should not be (None) - credentialManager.getServiceCredentialProvider("hive") should not be (None) - } - - test("disable hive credential provider") { - sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false") - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - - credentialManager.getServiceCredentialProvider("hadoopfs") should not be (None) - credentialManager.getServiceCredentialProvider("hbase") should not be (None) - credentialManager.getServiceCredentialProvider("hive") should be (None) - } - - test("using deprecated configurations") { - sparkConf.set("spark.yarn.security.tokens.hadoopfs.enabled", "false") - sparkConf.set("spark.yarn.security.tokens.hive.enabled", "false") - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - - credentialManager.getServiceCredentialProvider("hadoopfs") should be (None) - credentialManager.getServiceCredentialProvider("hive") should be (None) - credentialManager.getServiceCredentialProvider("test") should not be (None) - credentialManager.getServiceCredentialProvider("hbase") should not be (None) - } - - test("verify obtaining credentials from provider") { - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - val creds = new Credentials() - - // Tokens can only be obtained from TestTokenProvider, for hdfs, hbase and hive tokens cannot - // be obtained. 
- credentialManager.obtainCredentials(hadoopConf, creds) - val tokens = creds.getAllTokens - tokens.size() should be (1) - tokens.iterator().next().getService should be (new Text("test")) - } - - test("verify getting credential renewal info") { - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - val creds = new Credentials() - - val testCredentialProvider = credentialManager.getServiceCredentialProvider("test").get - .asInstanceOf[TestCredentialProvider] - // Only TestTokenProvider can get the time of next token renewal - val nextRenewal = credentialManager.obtainCredentials(hadoopConf, creds) - nextRenewal should be (testCredentialProvider.timeOfNextTokenRenewal) - } - - test("obtain tokens For HiveMetastore") { - val hadoopConf = new Configuration() - hadoopConf.set("hive.metastore.kerberos.principal", "bob") - // thrift picks up on port 0 and bails out, without trying to talk to endpoint - hadoopConf.set("hive.metastore.uris", "http://localhost:0") - - val hiveCredentialProvider = new HiveCredentialProvider() - val credentials = new Credentials() - hiveCredentialProvider.obtainCredentials(hadoopConf, sparkConf, credentials) - - credentials.getAllTokens.size() should be (0) - } - - test("Obtain tokens For HBase") { - val hadoopConf = new Configuration() - hadoopConf.set("hbase.security.authentication", "kerberos") - - val hbaseTokenProvider = new HBaseCredentialProvider() - val creds = new Credentials() - hbaseTokenProvider.obtainCredentials(hadoopConf, sparkConf, creds) - - creds.getAllTokens.size should be (0) - } -} - -class TestCredentialProvider extends ServiceCredentialProvider { - val tokenRenewalInterval = 86400 * 1000L - var timeOfNextTokenRenewal = 0L - - override def serviceName: String = "test" - - override def credentialsRequired(conf: Configuration): Boolean = true - - override def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] = { - if (creds == null) { - // Guard out other unit test failures. - return None - } - - val emptyToken = new Token() - emptyToken.setService(new Text("test")) - creds.addToken(emptyToken.getService, emptyToken) - - val currTime = System.currentTimeMillis() - timeOfNextTokenRenewal = (currTime - currTime % tokenRenewalInterval) + tokenRenewalInterval - - Some(timeOfNextTokenRenewal) - } -} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala deleted file mode 100644 index f50ee193c258f..0000000000000 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.scalatest.{Matchers, PrivateMethodTester} - -import org.apache.spark.{SparkException, SparkFunSuite} - -class HadoopFSCredentialProviderSuite - extends SparkFunSuite - with PrivateMethodTester - with Matchers { - private val _getTokenRenewer = PrivateMethod[String]('getTokenRenewer) - - private def getTokenRenewer( - fsCredentialProvider: HadoopFSCredentialProvider, conf: Configuration): String = { - fsCredentialProvider invokePrivate _getTokenRenewer(conf) - } - - private var hadoopFsCredentialProvider: HadoopFSCredentialProvider = null - - override def beforeAll() { - super.beforeAll() - - if (hadoopFsCredentialProvider == null) { - hadoopFsCredentialProvider = new HadoopFSCredentialProvider() - } - } - - override def afterAll() { - if (hadoopFsCredentialProvider != null) { - hadoopFsCredentialProvider = null - } - - super.afterAll() - } - - test("check token renewer") { - val hadoopConf = new Configuration() - hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") - hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") - val renewer = getTokenRenewer(hadoopFsCredentialProvider, hadoopConf) - renewer should be ("yarn/myrm:8032@SPARKTEST.COM") - } - - test("check token renewer default") { - val hadoopConf = new Configuration() - val caught = - intercept[SparkException] { - getTokenRenewer(hadoopFsCredentialProvider, hadoopConf) - } - assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") - } -} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala new file mode 100644 index 0000000000000..2b226eff5ce19 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.Credentials +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil + +class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { + private var credentialManager: YARNHadoopDelegationTokenManager = null + private var sparkConf: SparkConf = null + private var hadoopConf: Configuration = null + + override def beforeAll(): Unit = { + super.beforeAll() + + System.setProperty("SPARK_YARN_MODE", "true") + + sparkConf = new SparkConf() + hadoopConf = new Configuration() + } + + override def afterAll(): Unit = { + super.afterAll() + + System.clearProperty("SPARK_YARN_MODE") + } + + test("Correctly loads credential providers") { + credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, hadoopConf)) + + credentialManager.credentialProviders.get("yarn-test") should not be (None) + } +} + +class YARNTestCredentialProvider extends ServiceCredentialProvider { + override def serviceName: String = "yarn-test" + + override def credentialsRequired(conf: Configuration): Boolean = true + + override def obtainCredentials( + hadoopConf: Configuration, + sparkConf: SparkConf, + creds: Credentials): Option[Long] = None +} diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 8d80f8eca5dba..36948ba52b064 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -131,6 +131,13 @@ + + org.scalatest + scalatest-maven-plugin + + -Xmx4g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + org.antlr antlr4-maven-plugin diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 43f7ff5cb4a36..ef5648c6dbe47 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -563,6 +563,7 @@ primaryExpression | CAST '(' expression AS dataType ')' #cast | FIRST '(' expression (IGNORE NULLS)? ')' #first | LAST '(' expression (IGNORE NULLS)? ')' #last + | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position | constant #constantDefault | ASTERISK #star | qualifiedName '.' 
ASTERISK #star @@ -720,6 +721,7 @@ nonReserved | SET | RESET | VIEW | REPLACE | IF + | POSITION | NO | DATA | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION @@ -850,6 +852,7 @@ MACRO: 'MACRO'; IGNORE: 'IGNORE'; IF: 'IF'; +POSITION: 'POSITION'; EQ : '=' | '=='; NSEQ: '<=>'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 86a73a319ec3f..7683ee7074e7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -267,16 +267,11 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - val array = - Invoke( - MapObjects( - p => deserializerFor(et, Some(p)), - getPath, - inferDataType(et)._1), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + MapObjects( + p => deserializerFor(et, Some(p)), + getPath, + inferDataType(et)._1, + customCollectionCls = Some(c)) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 87130532c89bc..d580cf4d3391c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -335,31 +335,12 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - val keyData = - Invoke( - MapObjects( - p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), - returnNullable = false), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) - - val valueData = - Invoke( - MapObjects( - p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), - returnNullable = false), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[scala.collection.immutable.Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) + CollectObjectsToMap( + p => deserializerFor(keyType, Some(p), walkedTypePath), + p => deserializerFor(valueType, Some(p), walkedTypePath), + getPath, + mirror.runtimeClass(t.typeSymbol.asClass) + ) case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => val udt = 
getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4245b70892d1c..877328164a8a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -240,6 +240,7 @@ object FunctionRegistry { expression[Log1p]("log1p"), expression[Log2]("log2"), expression[Log]("ln"), + expression[Remainder]("mod"), expression[UnaryMinus]("negative"), expression[Pi]("pi"), expression[Pmod]("pmod"), @@ -325,6 +326,7 @@ object FunctionRegistry { expression[StringTrimLeft]("ltrim"), expression[JsonTuple]("json_tuple"), expression[ParseUrl]("parse_url"), + expression[StringLocate]("position"), expression[FormatString]("printf"), expression[RegExpExtract]("regexp_extract"), expression[RegExpReplace]("regexp_replace"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e1dd010d37a95..1f217390518a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { FunctionArgumentConversion :: CaseWhenCoercion :: IfCoercion :: + StackCoercion :: Division :: PropagateTypes :: ImplicitTypeCasts :: @@ -648,6 +649,22 @@ object TypeCoercion { } } + /** + * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. + */ + object StackCoercion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => + Stack(children.zipWithIndex.map { + // The first child is the number of rows for stack. 
+ case (e, 0) => e + case (Literal(null, NullType), index: Int) => + Literal.create(null, s.findDataType(index)) + case (e, _) => e + }) + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 974ef900e2eed..12ba5aedde026 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -160,6 +160,8 @@ abstract class ExternalCatalog */ def alterTableSchema(db: String, table: String, schema: StructType): Unit + def alterTableStats(db: String, table: String, stats: CatalogStatistics): Unit + def getTable(db: String, table: String): CatalogTable def getTableOption(db: String, table: String): Option[CatalogTable] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8a5319bebe54e..9820522a230e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -312,6 +312,15 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = schema) } + override def alterTableStats( + db: String, + table: String, + stats: CatalogStatistics): Unit = synchronized { + requireTableExists(db, table) + val origTable = catalog(db).tables(table).table + catalog(db).tables(table).table = origTable.copy(stats = Some(stats)) + } + override def getTable(db: String, table: String): CatalogTable = synchronized { requireTableExists(db, table) catalog(db).tables(table).table diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index b6744a7f53a54..cf02da8993658 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -376,6 +376,19 @@ class SessionCatalog( schema.fields.map(_.name).exists(conf.resolver(_, colName)) } + /** + * Alter Spark's statistics of an existing metastore table identified by the provided table + * identifier. + */ + def alterTableStats(identifier: TableIdentifier, newStats: CatalogStatistics): Unit = { + val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + externalCatalog.alterTableStats(db, table, newStats) + } + /** * Return whether a table/view with the specified name exists. If no database is specified, check * with current database. 
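For illustration only, not part of the patch: a minimal sketch of how the alterTableStats entry point added above might be invoked, e.g. from an ANALYZE-style command. It assumes the CatalogStatistics case class in this branch exposes sizeInBytes and an optional rowCount; the helper name is made up for the example.

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, SessionCatalog}

// Hypothetical helper: persist freshly computed statistics for a table.
def recordTableStats(
    catalog: SessionCatalog,
    db: String,
    table: String,
    sizeInBytes: BigInt,
    rowCount: Option[BigInt]): Unit = {
  // SessionCatalog.alterTableStats checks that the database and table exist, then delegates
  // to ExternalCatalog.alterTableStats (InMemoryCatalog above, or the Hive-backed catalog).
  catalog.alterTableStats(
    TableIdentifier(table, Some(db)),
    CatalogStatistics(sizeInBytes = sizeInBytes, rowCount = rowCount))
}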
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 2f328ccc49451..c043ed9c431b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -75,7 +75,7 @@ case class CatalogStorageFormat( CatalogUtils.maskCredentials(properties) match { case props if props.isEmpty => // No-op case props => - map.put("Properties", props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]")) + map.put("Storage Properties", props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]")) } map } @@ -316,7 +316,7 @@ case class CatalogTable( } } - if (properties.nonEmpty) map.put("Properties", tableProperties) + if (properties.nonEmpty) map.put("Table Properties", tableProperties) stats.foreach(s => map.put("Statistics", s.simpleString)) map ++= storage.toLinkedHashMap if (tracksPartitionsInCatalog) map.put("Partition Provider", "Catalog") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index af1eba26621bd..a54f6d0e11147 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -988,7 +988,7 @@ case class ScalaUDF( val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 ctx.addMutableState(converterClassName, converterTerm, - s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + + s"$converterTerm = ($converterClassName)$typeConvertersClassName" + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") converterTerm @@ -1005,7 +1005,7 @@ case class ScalaUDF( // Generate codes used to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") ctx.addMutableState(converterClassName, catalystConverterTerm, - s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1019,7 +1019,7 @@ case class ScalaUDF( val funcTerm = ctx.freshName("udf") ctx.addMutableState(funcClassName, funcTerm, - s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") + s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f2b252259b89d..ec6e6ba0f091b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -320,6 +320,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic Examples: > SELECT 2 _FUNC_ 1.8; 0.2 + > SELECT MOD(2, 1.8); + 0.2 """) case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fd9780245fcfb..5158949b95629 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -28,7 +28,6 @@ import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} -import org.apache.commons.lang3.exception.ExceptionUtils import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler} import org.codehaus.janino.util.ClassFile @@ -113,7 +112,7 @@ class CodegenContext { val idx = references.length references += obj val clsName = Option(className).getOrElse(obj.getClass.getName) - addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + addMutableState(clsName, term, s"$term = ($clsName) references[$idx];") term } @@ -202,16 +201,6 @@ class CodegenContext { partitionInitializationStatements.mkString("\n") } - /** - * Holding all the functions those will be added into generated class. - */ - val addedFunctions: mutable.Map[String, String] = - mutable.Map.empty[String, String] - - def addNewFunction(funcName: String, funcCode: String): Unit = { - addedFunctions += ((funcName, funcCode)) - } - /** * Holds expressions that are equivalent. Used to perform subexpression elimination * during codegen. @@ -233,10 +222,118 @@ class CodegenContext { // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] - def declareAddedFunctions(): String = { - addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + val outerClassName = "OuterClass" + + /** + * Holds the class and instance names to be generated, where `OuterClass` is a placeholder + * standing for whichever class is generated as the outermost class and which will contain any + * nested sub-classes. All other classes and instance names in this list will represent private, + * nested sub-classes. + */ + private val classes: mutable.ListBuffer[(String, String)] = + mutable.ListBuffer[(String, String)](outerClassName -> null) + + // A map holding the current size in bytes of each class to be generated. + private val classSize: mutable.Map[String, Int] = + mutable.Map[String, Int](outerClassName -> 0) + + // Nested maps holding function names and their code belonging to each class. + private val classFunctions: mutable.Map[String, mutable.Map[String, String]] = + mutable.Map(outerClassName -> mutable.Map.empty[String, String]) + + // Returns the size of the most recently added class. + private def currClassSize(): Int = classSize(classes.head._1) + + // Returns the class name and instance name for the most recently added class. + private def currClass(): (String, String) = classes.head + + // Adds a new class. Requires the class' name, and its instance name. + private def addClass(className: String, classInstance: String): Unit = { + classes.prepend(className -> classInstance) + classSize += className -> 0 + classFunctions += className -> mutable.Map.empty[String, String] + } + + /** + * Adds a function to the generated class. 
If the code for the `OuterClass` grows too large, the + function will be inlined into a new private, nested class, and a class-qualified name for the + function will be returned. Otherwise, the function will be inlined into the `OuterClass` and the + simple `funcName` will be returned. + * + * @param funcName the class-unqualified name of the function + * @param funcCode the body of the function + * @param inlineToOuterClass whether the given code must be inlined to the `OuterClass`. This + * can be necessary when a function is declared outside of the context + * in which it is eventually referenced and a returned qualified function name + * cannot otherwise be accessed. + * @return the name of the function, qualified by class if it will be inlined to a private, + * nested sub-class + */ + def addNewFunction( + funcName: String, + funcCode: String, + inlineToOuterClass: Boolean = false): String = { + // The number of named constants that can exist in the class is limited by the Constant Pool + // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a + // threshold of 1600k bytes to determine when a function should be inlined to a private, nested + // sub-class. + val (className, classInstance) = if (inlineToOuterClass) { + outerClassName -> "" + } else if (currClassSize > 1600000) { + val className = freshName("NestedClass") + val classInstance = freshName("nestedClassInstance") + + addClass(className, classInstance) + + className -> classInstance + } else { + currClass() + } + + classSize(className) += funcCode.length + classFunctions(className) += funcName -> funcCode + + if (className == outerClassName) { + funcName + } else { + + s"$classInstance.$funcName" + } + } + + /** + * Instantiates all nested, private sub-classes as objects to the `OuterClass` + */ + private[sql] def initNestedClasses(): String = { + // Nested, private sub-classes have no mutable state (though they do reference the outer class' + // mutable state), so we declare and initialize them inline to the OuterClass. + classes.filter(_._1 != outerClassName).map { + case (className, classInstance) => + s"private $className $classInstance = new $className();" + }.mkString("\n") + } + + /** + * Declares all function code that should be inlined to the `OuterClass`. + */ + private[sql] def declareAddedFunctions(): String = { + classFunctions(outerClassName).values.mkString("\n") } + /** + * Declares all nested, private sub-classes and the function code that should be inlined to them.
+ */ + private[sql] def declareNestedClasses(): String = { + classFunctions.filterKeys(_ != outerClassName).map { + case (className, functions) => + s""" + |private class $className { + | ${functions.values.mkString("\n")} + |} + """.stripMargin + } + }.mkString("\n") + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -556,8 +653,7 @@ class CodegenContext { return 0; } """ - addNewFunction(compareFunc, funcCode) - s"this.$compareFunc($c1, $c2)" + s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") @@ -573,8 +669,7 @@ class CodegenContext { return 0; } """ - addNewFunction(compareFunc, funcCode) - s"this.$compareFunc($c1, $c2)" + s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => @@ -629,7 +724,9 @@ class CodegenContext { /** * Splits the generated code of expressions into multiple functions, because function has - * 64kb code size limit in JVM + * 64kb code size limit in JVM. If the class to which the function would be inlined would grow + * beyond 1600kb, we declare a private, nested sub-class, and the function is inlined to it + * instead, because classes have a constant pool limit of 65,536 named values. * * @param row the variable name of row that is used by expressions * @param expressions the codes to evaluate expressions. @@ -689,7 +786,6 @@ class CodegenContext { |} """.stripMargin addNewFunction(name, code) - name } foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})")) @@ -773,8 +869,6 @@ class CodegenContext { |} """.stripMargin - addNewFunction(fnName, fn) - // Add a state and a mapping of the common subexpressions that are associate with this // state. Adding this expression to subExprEliminationExprMap means it will call `fn` // when it is code generated. This decision should be a cost based one. 
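For illustration only, not part of the patch: the contract change in this file is that addNewFunction now returns the name the caller must use in generated code, and that name may be qualified by a nested-class instance once the outer class passes the size threshold. A sketch of the expected caller pattern, mirroring the genComp hunks above; identifiers are illustrative.

// Given a CodegenContext `ctx`, register a helper and call it through the returned name.
val compareFunc = ctx.freshName("compareStruct")
val funcCode =
  s"""
     |public int $compareFunc(InternalRow a, InternalRow b) {
     |  // ... generated field-by-field comparisons ...
     |  return 0;
     |}
   """.stripMargin
// addNewFunction may return "compareStruct_0" or "nestedClassInstance_0.compareStruct_0",
// so the call site must use the returned value and must not prefix it with `this.`.
val compareCall = s"${ctx.addNewFunction(compareFunc, funcCode)}(a, b)"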
@@ -792,7 +886,7 @@ class CodegenContext { addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subexprFunctions += s"$fnName($INPUT_ROW);" + subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4d732445544a8..635766835029b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -63,21 +63,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState("boolean", isNull, s"$isNull = true;") ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") + s"$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - this.$isNull = ${ev.isNull}; - this.$value = ${ev.value}; + $isNull = ${ev.isNull}; + $value = ${ev.value}; """ } else { val value = s"value_$i" ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") + s"$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - this.$value = ${ev.value}; + $value = ${ev.value}; """ } } @@ -87,7 +87,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(index).map { case (e, i) => - val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + val ev = ExprCode("", s"isNull_$i", s"value_$i") ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } @@ -135,6 +135,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP $allUpdates return mutableRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index f7fc2d54a047b..a31943255b995 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -179,6 +179,9 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR $comparisons return 0; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index dcd1ed96a298e..b400783bb5e55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -72,6 +72,9 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { ${eval.code} return 
!${eval.isNull} && ${eval.value}; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b1cb6edefb852..f708aeff2b146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -49,7 +49,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val output = ctx.freshName("safeRow") val values = ctx.freshName("values") // These expressions could be split into multiple functions - ctx.addMutableState("Object[]", values, s"this.$values = null;") + ctx.addMutableState("Object[]", values, s"$values = null;") val rowClass = classOf[GenericInternalRow].getName @@ -65,10 +65,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - this.$values = new Object[${schema.length}]; + $values = new Object[${schema.length}]; $allFields final InternalRow $output = new $rowClass($values); - this.$values = null; + $values = null; """ ExprCode(code, "false", output) @@ -184,6 +184,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] $allExpressions return mutableRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index efbbc038bd33b..6be69d119bf8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -82,7 +82,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") ctx.addMutableState(rowWriterClass, rowWriter, - s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, @@ -182,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, - s"this.$arrayWriter = new $arrayWriterClass();") + s"$arrayWriter = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") val element = ctx.freshName("element") @@ -321,7 +321,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, holder, - s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") + s"$holder = new $holderClass($result, 
${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" @@ -402,6 +402,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${eval.code.trim} return ${eval.value}; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b6675a84ece48..98c4cbee38dee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -93,7 +93,7 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, - s"this.$arrayName = new Object[${numElements}];") + s"$arrayName = new Object[${numElements}];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -340,7 +340,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") + ctx.addMutableState("Object[]", values, s"$values = null;") ev.copy(code = s""" $values = new Object[${valExprs.size}];""" + @@ -357,7 +357,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc }) + s""" final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; + $values = null; """, isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ee365fe636614..ae8efb673f91c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -131,8 +131,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi | $globalValue = ${ev.value}; |} """.stripMargin - ctx.addNewFunction(funcName, funcBody) - (funcName, globalIsNull, globalValue) + val fullFuncName = ctx.addNewFunction(funcName, funcBody) + (fullFuncName, globalIsNull, globalValue) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e84796f2edad0..c217aa875d9eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -138,6 +138,13 @@ case class Stack(children: Seq[Expression]) extends Generator { private lazy val numRows = children.head.eval().asInstanceOf[Int] private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt + /** + * Return true iff the first child exists and has a foldable IntegerType. 
+ */ + def hasFoldableNumRows: Boolean = { + children.nonEmpty && children.head.dataType == IntegerType && children.head.foldable + } + override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.") @@ -156,6 +163,18 @@ case class Stack(children: Seq[Expression]) extends Generator { } } + def findDataType(index: Int): DataType = { + // Find the first data type except NullType. + val firstDataIndex = ((index - 1) % numFields) + 1 + for (i <- firstDataIndex until children.length by numFields) { + if (children(i).dataType != NullType) { + return children(i).dataType + } + } + // If all values of the column are NullType, use it. + NullType + } + override def elementSchema: StructType = StructType(children.tail.take(numFields).zipWithIndex.map { case (e, index) => StructField(s"col$index", e.dataType) @@ -181,7 +200,7 @@ case class Stack(children: Seq[Expression]) extends Generator { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Rows - we write these into an array. val rowData = ctx.freshName("rows") - ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];") + ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => @@ -190,7 +209,7 @@ case class Stack(children: Seq[Expression]) extends Generator { if (index < values.length) values(index) else Literal(null, dataTypes(col)) } val eval = CreateStruct(fields).genCode(ctx) - s"${eval.code}\nthis.$rowData[$row] = ${eval.value};" + s"${eval.code}\n$rowData[$row] = ${eval.value};" }) // Create the collection. 
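For illustration only, not part of the patch: hasFoldableNumRows and findDataType above support the StackCoercion rule added earlier in this diff, which rewrites a bare NULL argument of stack() into a null literal typed from another value at the same column position (previously the NullType mismatch could fail Stack's type check). A rough sketch of the visible effect, assuming a SparkSession named spark:

// The NULL sits in the same column position as the decimal literal 1.1, so StackCoercion
// retypes it to that column's type instead of leaving it as NullType.
val df = spark.sql("SELECT stack(2, 1, 1.1, 3, NULL)")
df.printSchema()
// Expected shape: col0 is an integer column, col1 a decimal column with a null in its second row.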
@@ -198,7 +217,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState( s"$wrapperClass", ev.value, - s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);") + s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);") ev.copy(code = code, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1a202ecf745c9..073993cccdf8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -22,6 +22,7 @@ import java.lang.reflect.Modifier import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag +import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ /** @@ -597,8 +598,8 @@ case class MapObjects private( val (initCollection, addElement, getResult): (String, String => String, String) = customCollectionCls match { - case Some(cls) => - // collection + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" val builder = ctx.freshName("collectionBuilder") ( @@ -609,6 +610,20 @@ case class MapObjects private( genValue => s"$builder.$$plus$$eq($genValue);", s"(${cls.getName}) $builder.result();" ) + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + val builder = ctx.freshName("collectionBuilder") + ( + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + s"${cls.getName} $builder = new java.util.ArrayList($dataLength);" + } else { + val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("") + s"${cls.getName} $builder = new ${cls.getName}($param);" + }, + genValue => s"$builder.add($genValue);", + s"$builder;" + ) case None => // array ( @@ -652,6 +667,173 @@ case class MapObjects private( } } +object CollectObjectsToMap { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + /** + * Construct an instance of CollectObjectsToMap case class. + * + * @param keyFunction The function applied on the key collection elements. + * @param valueFunction The function applied on the value collection elements. + * @param inputData An expression that when evaluated returns a map object. + * @param collClass The type of the resulting collection. 
+ */ + def apply( + keyFunction: Expression => Expression, + valueFunction: Expression => Expression, + inputData: Expression, + collClass: Class[_]): CollectObjectsToMap = { + val id = curId.getAndIncrement() + val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id" + val mapType = inputData.dataType.asInstanceOf[MapType] + val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) + val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" + val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" + val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) + CollectObjectsToMap( + keyLoopValue, keyFunction(keyLoopVar), + valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar), + inputData, collClass) + } +} + +/** + * Expression used to convert a Catalyst Map to an external Scala Map. + * The collection is constructed using the associated builder, obtained by calling `newBuilder` + * on the collection's companion object. + * + * @param keyLoopValue the name of the loop variable that is used when iterating over the key + * collection, and which is used as input for the `keyLambdaFunction` + * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param valueLoopValue the name of the loop variable that is used when iterating over the value + * collection, and which is used as input for the `valueLambdaFunction` + * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over + * the value collection, and which is used as input for the + * `valueLambdaFunction` + * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param inputData An expression that when evaluated returns a map object. + * @param collClass The type of the resulting collection. + */ +case class CollectObjectsToMap private( + keyLoopValue: String, + keyLambdaFunction: Expression, + valueLoopValue: String, + valueLoopIsNull: String, + valueLambdaFunction: Expression, + inputData: Expression, + collClass: Class[_]) extends Expression with NonSQLExpression { + + override def nullable: Boolean = inputData.nullable + + override def children: Seq[Expression] = + keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def dataType: DataType = ObjectType(collClass) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. 
+ def inputDataType(dataType: DataType) = dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => dataType + } + + val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] + val keyElementJavaType = ctx.javaType(mapType.keyType) + ctx.addMutableState(keyElementJavaType, keyLoopValue, "") + val genKeyFunction = keyLambdaFunction.genCode(ctx) + val valueElementJavaType = ctx.javaType(mapType.valueType) + ctx.addMutableState("boolean", valueLoopIsNull, "") + ctx.addMutableState(valueElementJavaType, valueLoopValue, "") + val genValueFunction = valueLambdaFunction.genCode(ctx) + val genInputData = inputData.genCode(ctx) + val dataLength = ctx.freshName("dataLength") + val loopIndex = ctx.freshName("loopIndex") + val tupleLoopValue = ctx.freshName("tupleLoopValue") + val builderValue = ctx.freshName("builderValue") + + val getLength = s"${genInputData.value}.numElements()" + + val keyArray = ctx.freshName("keyArray") + val valueArray = ctx.freshName("valueArray") + val getKeyArray = + s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" + val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getValueArray = + s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" + val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value" + def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) = + lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) + val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) + + val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + + val builderClass = classOf[Builder[_, _]].getName + val constructBuilder = s""" + $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder(); + $builderValue.sizeHint($dataLength); + """ + + val tupleClass = classOf[(_, _)].getName + val appendToBuilder = s""" + $tupleClass $tupleLoopValue; + + if (${genValueFunction.isNull}) { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); + } else { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); + } + + $builderValue.$$plus$$eq($tupleLoopValue); + """ + val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" + + val code = s""" + ${genInputData.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + + if (!${genInputData.isNull}) { + int $dataLength = $getLength; + $constructBuilder + $getKeyArray + $getValueArray + + int $loopIndex = 0; + while ($loopIndex < $dataLength) { + $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar); + $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar); + $valueLoopNullCheck + + ${genKeyFunction.code} + ${genValueFunction.code} + + $appendToBuilder + + $loopIndex += 1; + } + + $getBuilderResult + } + """ + ev.copy(code = code, isNull = genInputData.isNull) + } +} + object 
ExternalMapToCatalyst { private val curId = new java.util.concurrent.atomic.AtomicInteger() @@ -981,7 +1163,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val code = s""" ${instanceGen.code} - this.${javaBeanInstance} = ${instanceGen.value}; + ${javaBeanInstance} = ${instanceGen.value}; if (!${instanceGen.isNull}) { $initializeCode } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 035a1afe8b782..717ada225a4f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -654,8 +654,12 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: """, extended = """ Examples: + > SELECT _FUNC_('bar', 'foobarbar'); + 4 > SELECT _FUNC_('bar', 'foobarbar', 5); 7 + > SELECT POSITION('bar' IN 'foobarbar'); + 4 """) // scalastyle:on line.size.limit case class StringLocate(substr: Expression, str: Expression, start: Expression) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 7930515038355..1fd680ab64b5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -81,7 +81,7 @@ private[sql] class JSONOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) - val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d16689a34298a..3ab70fb90470c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -77,12 +77,12 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) // Operator push down PushProjectionThroughUnion, ReorderJoin(conf), - EliminateOuterJoin(conf), + EliminateOuterJoin, PushPredicateThroughJoin, PushDownPredicate, LimitPushDown(conf), ColumnPruning, - InferFiltersFromConstraints(conf), + InferFiltersFromConstraints, // Operator combine CollapseRepartition, CollapseProject, @@ -102,7 +102,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - PruneFilters(conf), + PruneFilters, EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, @@ -619,14 +619,15 @@ object CollapseWindow extends Rule[LogicalPlan] { * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and * LeftSemi joins. 
*/ -case class InferFiltersFromConstraints(conf: SQLConf) - extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) { - inferFilters(plan) - } else { - plan - } +object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = { + if (SQLConf.get.constraintPropagationEnabled) { + inferFilters(plan) + } else { + plan + } + } private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, child) => @@ -717,7 +718,7 @@ object EliminateSorts extends Rule[LogicalPlan] { * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`. * 3) by eliminating the always-true conditions given the constraints on the child's output. */ -case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { +object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child @@ -730,7 +731,7 @@ case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateH case f @ Filter(fc, p: LogicalPlan) => val (prunedPredicates, remainingPredicates) = splitConjunctivePredicates(fc).partition { cond => - cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond) + cond.deterministic && p.constraints.contains(cond) } if (prunedPredicates.isEmpty) { f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 51f749a8bf857..66b8ca62e5e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -383,22 +383,27 @@ object LikeSimplification extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Like(input, Literal(pattern, StringType)) => - pattern.toString match { - case startsWith(prefix) if !prefix.endsWith("\\") => - StartsWith(input, Literal(prefix)) - case endsWith(postfix) => - EndsWith(input, Literal(postfix)) - // 'a%a' pattern is basically same with 'a%' && '%a'. - // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. - case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => - And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)), - And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) - case contains(infix) if !infix.endsWith("\\") => - Contains(input, Literal(infix)) - case equalTo(str) => - EqualTo(input, Literal(str)) - case _ => - Like(input, Literal.create(pattern, StringType)) + if (pattern == null) { + // If pattern is null, return null value directly, since "col like null" == null. + Literal(null, BooleanType) + } else { + pattern.toString match { + case startsWith(prefix) if !prefix.endsWith("\\") => + StartsWith(input, Literal(prefix)) + case endsWith(postfix) => + EndsWith(input, Literal(postfix)) + // 'a%a' pattern is basically same with 'a%' && '%a'. + // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. 
+ case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => + And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)), + And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) + case contains(infix) if !infix.endsWith("\\") => + Contains(input, Literal(infix)) + case equalTo(str) => + EqualTo(input, Literal(str)) + case _ => + Like(input, Literal.create(pattern, StringType)) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 2fe3039774423..bb97e2c808b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -113,7 +113,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe * * This rule should be executed before pushing down the Filter */ -case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { +object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Returns whether the expression returns null or false when all inputs are nulls. @@ -129,8 +129,7 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred } private def buildNewJoinType(filter: Filter, join: Join): JoinType = { - val conditions = splitConjunctivePredicates(filter.condition) ++ - filter.getConstraints(conf.constraintPropagationEnabled) + val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a16611af28a7d..500d999c30da7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1076,6 +1076,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() } + /** + * Create a Position expression. + */ + override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) { + new StringLocate(expression(ctx.substr), expression(ctx.str)) + } + /** * Create a (windowed) Function expression. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 5ba043e17a128..9130b14763e24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -19,194 +19,18 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} -abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { - self: PlanType => - - def output: Seq[Attribute] - - /** - * Extracts the relevant constraints from a given set of constraints based on the attributes that - * appear in the [[outputSet]]. - */ - protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { - constraints - .union(inferAdditionalConstraints(constraints)) - .union(constructIsNotNullConstraints(constraints)) - .filter(constraint => - constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) && - constraint.deterministic) - } - - /** - * Infers a set of `isNotNull` constraints from null intolerant expressions as well as - * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this - * returns a constraint of the form `isNotNull(a)` - */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { - // First, we propagate constraints from the null intolerant expressions. - var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) - - // Second, we infer additional constraints from non-nullable attributes that are part of the - // operator's output - val nonNullableAttributes = output.filterNot(_.nullable) - isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet - - isNotNullConstraints -- constraints - } - - /** - * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions - * of constraints. - */ - private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = - constraint match { - // When the root is IsNotNull, we can push IsNotNull through the child null intolerant - // expressions - case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) - // Constraints always return true for all the inputs. That means, null will never be returned. - // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child - // null intolerant expressions. - case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) - } - - /** - * Recursively explores the expressions which are null intolerant and returns all attributes - * in these expressions. - */ - private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { - case a: Attribute => Seq(a) - case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) - case _ => Seq.empty[Attribute] - } - - // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so - // we may avoid producing recursive constraints. 
- private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( - expressions.collect { - case a: Alias => (a.toAttribute, a.child) - } ++ children.flatMap(_.aliasMap)) - - /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. - * - * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)` - * as they are often useless and can lead to a non-converging set of constraints. - */ - private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - val constraintClasses = generateEquivalentConstraintClasses(constraints) - - var inferredConstraints = Set.empty[Expression] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= candidateConstraints.map(_ transform { - case a: Attribute if a.semanticEquals(l) && - !isRecursiveDeduction(r, constraintClasses) => r - }) - inferredConstraints ++= candidateConstraints.map(_ transform { - case a: Attribute if a.semanticEquals(r) && - !isRecursiveDeduction(l, constraintClasses) => l - }) - case _ => // No inference - } - inferredConstraints -- constraints - } - - /* - * Generate a sequence of expression sets from constraints, where each set stores an equivalence - * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following - * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal - * to an selected attribute. - */ - private def generateEquivalentConstraintClasses( - constraints: Set[Expression]): Seq[Set[Expression]] = { - var constraintClasses = Seq.empty[Set[Expression]] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - // Transform [[Alias]] to its child. - val left = aliasMap.getOrElse(l, l) - val right = aliasMap.getOrElse(r, r) - // Get the expression set for an equivalence constraint class. - val leftConstraintClass = getConstraintClass(left, constraintClasses) - val rightConstraintClass = getConstraintClass(right, constraintClasses) - if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) { - // Combine the two sets. - constraintClasses = constraintClasses - .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+ - (leftConstraintClass ++ rightConstraintClass) - } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty - // Update equivalence class of `left` expression. - constraintClasses = constraintClasses - .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right) - } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty - // Update equivalence class of `right` expression. - constraintClasses = constraintClasses - .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left) - } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty - // Create new equivalence constraint class since neither expression presents - // in any classes. - constraintClasses = constraintClasses :+ Set(left, right) - } - case _ => // Skip - } +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] + extends TreeNode[PlanType] + with QueryPlanConstraints[PlanType] { - constraintClasses - } - - /* - * Get all expressions equivalent to the selected expression. 
- */ - private def getConstraintClass( - expr: Expression, - constraintClasses: Seq[Set[Expression]]): Set[Expression] = - constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression]) - - /* - * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it - * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function. - * Here we first get all expressions equal to `attr` and then check whether at least one of them - * is a child of the referenced expression. - */ - private def isRecursiveDeduction( - attr: Attribute, - constraintClasses: Seq[Set[Expression]]): Boolean = { - val expr = aliasMap.getOrElse(attr, attr) - getConstraintClass(expr, constraintClasses).exists { e => - expr.children.exists(_.semanticEquals(e)) - } - } - - /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. - */ - lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) + self: PlanType => - /** - * Returns [[constraints]] depending on the config of enabling constraint propagation. If the - * flag is disabled, simply returning an empty constraints. - */ - private[spark] def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet = - if (constraintPropagationEnabled) { - constraints - } else { - ExpressionSet(Set.empty) - } + def conf: SQLConf = SQLConf.get - /** - * This method can be overridden by any child class of QueryPlan to specify a set of constraints - * based on the given operator's constraint propagation logic. These constraints are then - * canonicalized and filtered automatically to contain only those attributes that appear in the - * [[outputSet]]. - * - * See [[Canonicalize]] for more details. - */ - protected def validConstraints: Set[Expression] = Set.empty + def output: Seq[Attribute] /** * Returns the set of attributes that are output by this node. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala new file mode 100644 index 0000000000000..b08a009f0dca1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.sql.catalyst.expressions._ + + +trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] => + + /** + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. + */ + lazy val constraints: ExpressionSet = { + if (conf.constraintPropagationEnabled) { + ExpressionSet( + validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + } + ) + } else { + ExpressionSet(Set.empty) + } + } + + /** + * This method can be overridden by any child class of QueryPlan to specify a set of constraints + * based on the given operator's constraint propagation logic. These constraints are then + * canonicalized and filtered automatically to contain only those attributes that appear in the + * [[outputSet]]. + * + * See [[Canonicalize]] for more details. + */ + protected def validConstraints: Set[Expression] = Set.empty + + /** + * Infers a set of `isNotNull` constraints from null intolerant expressions as well as + * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this + * returns a constraint of the form `isNotNull(a)` + */ + private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + // First, we propagate constraints from the null intolerant expressions. + var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) + + // Second, we infer additional constraints from non-nullable attributes that are part of the + // operator's output + val nonNullableAttributes = output.filterNot(_.nullable) + isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet + + isNotNullConstraints -- constraints + } + + /** + * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions + * of constraints. + */ + private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = + constraint match { + // When the root is IsNotNull, we can push IsNotNull through the child null intolerant + // expressions + case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) + // Constraints always return true for all the inputs. That means, null will never be returned. + // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child + // null intolerant expressions. + case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) + } + + /** + * Recursively explores the expressions which are null intolerant and returns all attributes + * in these expressions. + */ + private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { + case a: Attribute => Seq(a) + case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) + case _ => Seq.empty[Attribute] + } + + // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so + // we may avoid producing recursive constraints. 
+ private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( + expressions.collect { + case a: Alias => (a.toAttribute, a.child) + } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints[PlanType]].aliasMap)) + + /** + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. + * + * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)` + * as they are often useless and can lead to a non-converging set of constraints. + */ + private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + val constraintClasses = generateEquivalentConstraintClasses(constraints) + + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = constraints - eq + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(l) && + !isRecursiveDeduction(r, constraintClasses) => r + }) + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(r) && + !isRecursiveDeduction(l, constraintClasses) => l + }) + case _ => // No inference + } + inferredConstraints -- constraints + } + + /** + * Generate a sequence of expression sets from constraints, where each set stores an equivalence + * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following + * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal + * to an selected attribute. + */ + private def generateEquivalentConstraintClasses( + constraints: Set[Expression]): Seq[Set[Expression]] = { + var constraintClasses = Seq.empty[Set[Expression]] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + // Transform [[Alias]] to its child. + val left = aliasMap.getOrElse(l, l) + val right = aliasMap.getOrElse(r, r) + // Get the expression set for an equivalence constraint class. + val leftConstraintClass = getConstraintClass(left, constraintClasses) + val rightConstraintClass = getConstraintClass(right, constraintClasses) + if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) { + // Combine the two sets. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+ + (leftConstraintClass ++ rightConstraintClass) + } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty + // Update equivalence class of `left` expression. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right) + } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty + // Update equivalence class of `right` expression. + constraintClasses = constraintClasses + .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left) + } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty + // Create new equivalence constraint class since neither expression presents + // in any classes. + constraintClasses = constraintClasses :+ Set(left, right) + } + case _ => // Skip + } + + constraintClasses + } + + /** + * Get all expressions equivalent to the selected expression. 
+ */ + private def getConstraintClass( + expr: Expression, + constraintClasses: Seq[Set[Expression]]): Set[Expression] = + constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression]) + + /** + * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it + * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function. + * Here we first get all expressions equal to `attr` and then check whether at least one of them + * is a child of the referenced expression. + */ + private def isRecursiveDeduction( + attr: Attribute, + constraintClasses: Seq[Set[Expression]]): Boolean = { + val expr = aliasMap.getOrElse(attr, attr) + getConstraintClass(expr, constraintClasses).exists { e => + expr.children.exists(_.semanticEquals(e)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index df66f9a082aee..7375a0bcbae75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = f(arg1.asInstanceOf[BaseType]) - val newChild2 = f(arg2.asInstanceOf[BaseType]) + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { changed = true (newChild1, newChild2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index efb42292634ad..746c3e8950f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -32,7 +32,7 @@ import org.apache.spark.unsafe.types.UTF8String * Helper functions for converting between internal and external date and time representations. * Dates are exposed externally as java.sql.Date and are represented internally as the number of * dates since the Unix epoch (1970-01-01). Timestamps are exposed externally as java.sql.Timestamp - * and are stored internally as longs, which are capable of storing timestamps with 100 nanosecond + * and are stored internally as longs, which are capable of storing timestamps with microsecond * precision. 
*/ object DateTimeUtils { @@ -399,13 +399,14 @@ object DateTimeUtils { digitsMilli += 1 } - if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) { - return None + // We are truncating the nanosecond part, which results in loss of precision + while (digitsMilli > 6) { + segments(6) /= 10 + digitsMilli -= 1 } - // Instead of return None, we truncate the fractional seconds to prevent inserting NULL - if (segments(6) > 999999) { - segments(6) = segments(6).toString.take(6).toInt + if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) { + return None } if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3ea808926e10b..6ab3a615e6cc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.internal import java.util.{Locale, NoSuchElementException, Properties, TimeZone} import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ import scala.collection.immutable @@ -64,6 +65,47 @@ object SQLConf { } } + /** + * Default config. Only used when there is no active SparkSession for the thread. + * See [[get]] for more information. + */ + private val fallbackConf = new ThreadLocal[SQLConf] { + override def initialValue: SQLConf = new SQLConf + } + + /** See [[get]] for more information. */ + def getFallbackConf: SQLConf = fallbackConf.get() + + /** + * Defines a getter that returns the SQLConf within scope. + * See [[get]] for more information. + */ + private val confGetter = new AtomicReference[() => SQLConf](() => fallbackConf.get()) + + /** + * Sets the active config object within the current scope. + * See [[get]] for more information. + */ + def setSQLConfGetter(getter: () => SQLConf): Unit = { + confGetter.set(getter) + } + + /** + * Returns the active config object within the current scope. If there is an active SparkSession, + * the proper SQLConf associated with the thread's session is used. + * + * The way this works is a little bit convoluted, due to the fact that config was added initially + * only for physical plans (and as a result not in sql/catalyst module). + * + * The first time a SparkSession is instantiated, we set the [[confGetter]] to return the + * active SparkSession's config. If there is no active SparkSession, it returns using the thread + * local [[fallbackConf]]. The reason [[fallbackConf]] is a thread local (rather than just a conf) + * is to support setting different config options for different threads so we can potentially + * run tests in parallel. At the time this feature was implemented, this was a no-op since we + * run unit tests (that does not involve SparkSession) in serial order. 
+ */ + def get: SQLConf = confGetter.get()() + val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() .doc("The max number of iterations the optimizer and analyzer runs.") @@ -552,12 +594,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val SUBQUERY_REUSE_ENABLED = buildConf("spark.sql.subquery.reuse") - .internal() - .doc("When true, the planner will try to find out duplicated subqueries and re-use them.") - .booleanConf - .createWithDefault(true) - val STATE_STORE_PROVIDER_CLASS = buildConf("spark.sql.streaming.stateStore.providerClass") .internal() @@ -938,8 +974,6 @@ class SQLConf extends Serializable with Logging { def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) - def subqueryReuseEnabled: Boolean = getConf(SUBQUERY_REUSE_ENABLED) - def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 80916ee9c5379..1f1fb51addfd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -126,7 +126,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(decimal: BigDecimal): Decimal = { this.decimalVal = decimal this.longVal = 0L - this._precision = decimal.precision + if (decimal.precision <= decimal.scale) { + // For Decimal, we expect the precision is equal to or large than the scale, however, + // in BigDecimal, the digit count starts from the leftmost nonzero digit of the exact + // result. For example, the precision of 0.01 equals to 1 based on the definition, but + // the scale is 2. The expected precision should be 3. 
+ this._precision = decimal.scale + 1 + } else { + this._precision = decimal.precision + } this._scale = decimal.scale this } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 70ad064f93ebc..ff2414b174acb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } + test("serialize and deserialize arbitrary map types") { + val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( + 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) + assert(mapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mapDeserializer = deserializerFor[Map[Int, Int]] + assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) + + import scala.collection.immutable.HashMap + val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( + 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) + assert(hashMapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] + assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) + + import scala.collection.mutable.{LinkedHashMap => LHMap} + val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( + 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) + assert(linkedHashMapSerializer.dataType.head.dataType == + MapType(LongType, StringType, valueContainsNull = true)) + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] + assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) + } + private val dataTypeForComplexData = dataTypeFor[ComplexData] private val typeOfComplexData = typeOf[ComplexData] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2624f5586fd5d..7358f401ed520 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -484,24 +484,50 @@ class TypeCoercionSuite extends PlanTest { } test("coalesce casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1.0) - :: Literal(1) - :: Literal.create(1.0, FloatType) - :: Nil), - Coalesce(Cast(Literal(1.0), DoubleType) - :: Cast(Literal(1), DoubleType) - :: Cast(Literal.create(1.0, FloatType), DoubleType) - :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1L) - :: Literal(1) - :: Literal(new java.math.BigDecimal("1000000000000000000000")) - :: Nil), - Coalesce(Cast(Literal(1L), DecimalType(22, 0)) - :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) - :: Nil)) + val rule = TypeCoercion.FunctionArgumentConversion + + val intLit = Literal(1) + val longLit = Literal.create(1L) + val doubleLit = Literal(1.0) + val stringLit = Literal.create("c", StringType) + val nullLit = Literal.create(null, 
NullType) + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + + ruleTest(rule, + Coalesce(Seq(doubleLit, intLit, floatLit)), + Coalesce(Seq(Cast(doubleLit, DoubleType), + Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(longLit, intLit, decimalLit)), + Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), + Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit)), + Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, intLit)), + Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), + Cast(intLit, FloatType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), + Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), + Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), + Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), + Cast(doubleLit, StringType), Cast(stringLit, StringType)))) } test("CreateArray casts") { @@ -675,6 +701,14 @@ class TypeCoercionSuite extends PlanTest { test("type coercion for If") { val rule = TypeCoercion.IfCoercion + val intLit = Literal(1) + val doubleLit = Literal(1.0) + val trueLit = Literal.create(true, BooleanType) + val falseLit = Literal.create(false, BooleanType) + val stringLit = Literal.create("c", StringType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), @@ -685,12 +719,32 @@ class TypeCoercionSuite extends PlanTest { If(Literal.create(null, BooleanType), Literal(1), Literal(1))) ruleTest(rule, - If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(AssertTrue(trueLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(trueLit), BooleanType), Literal(1), Literal(2))) ruleTest(rule, - If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(AssertTrue(falseLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(falseLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(trueLit, intLit, doubleLit), + If(trueLit, Cast(intLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, doubleLit), + If(trueLit, Cast(floatLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, decimalLit), + If(trueLit, Cast(floatLit, DoubleType), Cast(decimalLit, DoubleType))) + + ruleTest(rule, + If(falseLit, stringLit, doubleLit), + If(falseLit, stringLit, Cast(doubleLit, StringType))) + + ruleTest(rule, + If(trueLit, timestampLit, stringLit), + If(trueLit, Cast(timestampLit, StringType), stringLit)) } test("type coercion for CaseKeyWhen") { @@ -714,6 +768,63 @@ class TypeCoercionSuite extends PlanTest { ) } + test("type coercion for 
Stack") { + val rule = TypeCoercion.StackCoercion + + ruleTest(rule, + Stack(Seq(Literal(3), Literal(1), Literal(2), Literal(null))), + Stack(Seq(Literal(3), Literal(1), Literal(2), Literal.create(null, IntegerType)))) + ruleTest(rule, + Stack(Seq(Literal(3), Literal(1.0), Literal(null), Literal(3.0))), + Stack(Seq(Literal(3), Literal(1.0), Literal.create(null, DoubleType), Literal(3.0)))) + ruleTest(rule, + Stack(Seq(Literal(3), Literal(null), Literal("2"), Literal("3"))), + Stack(Seq(Literal(3), Literal.create(null, StringType), Literal("2"), Literal("3")))) + ruleTest(rule, + Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))), + Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null)))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(1), Literal("2"), + Literal(null), Literal(null))), + Stack(Seq(Literal(2), + Literal(1), Literal("2"), + Literal.create(null, IntegerType), Literal.create(null, StringType)))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(1), Literal(null), + Literal(null), Literal("2"))), + Stack(Seq(Literal(2), + Literal(1), Literal.create(null, StringType), + Literal.create(null, IntegerType), Literal("2")))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(null), Literal(1), + Literal("2"), Literal(null))), + Stack(Seq(Literal(2), + Literal.create(null, StringType), Literal(1), + Literal("2"), Literal.create(null, IntegerType)))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(null), Literal(null), + Literal(1), Literal("2"))), + Stack(Seq(Literal(2), + Literal.create(null, IntegerType), Literal.create(null, StringType), + Literal(1), Literal("2")))) + + ruleTest(rule, + Stack(Seq(Subtract(Literal(3), Literal(1)), + Literal(1), Literal("2"), + Literal(null), Literal(null))), + Stack(Seq(Subtract(Literal(3), Literal(1)), + Literal(1), Literal("2"), + Literal.create(null, IntegerType), Literal.create(null, StringType)))) + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 1759ac04c0033..557b0970b54e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -245,7 +245,6 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("alter table schema") { val catalog = newBasicCatalog() - val tbl1 = catalog.getTable("db2", "tbl1") val newSchema = StructType(Seq( StructField("col1", IntegerType), StructField("new_field_2", StringType), @@ -256,6 +255,16 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(newTbl1.schema == newSchema) } + test("alter table stats") { + val catalog = newBasicCatalog() + val oldTableStats = catalog.getTable("db2", "tbl1").stats + assert(oldTableStats.isEmpty) + val newStats = CatalogStatistics(sizeInBytes = 1) + catalog.alterTableStats("db2", "tbl1", newStats) + val newTableStats = catalog.getTable("db2", "tbl1").stats + assert(newTableStats.get == newStats) + } + test("get table") { assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 5afeb0e8ca032..dce73b3635e72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -448,6 +448,18 @@ abstract class SessionCatalogSuite extends PlanTest { } } + test("alter table stats") { + withBasicCatalog { catalog => + val tableId = TableIdentifier("tbl1", Some("db2")) + val oldTableStats = catalog.getTableMetadata(tableId).stats + assert(oldTableStats.isEmpty) + val newStats = CatalogStatistics(sizeInBytes = 1) + catalog.alterTableStats(tableId, newStats) + val newTableStats = catalog.getTableMetadata(tableId).stats + assert(newTableStats.get == newStats) + } + } + test("alter table add columns") { withBasicCatalog { sessionCatalog => sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 6af0cde73538b..f4d5a4471d896 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -23,6 +23,7 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts.implicitCast import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer @@ -223,6 +224,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { def f: (Double) => Double = (x: Double) => 1 / math.tan(x) testUnary(Cot, f) 
checkConsistencyBetweenInterpretedAndCodegen(Cot, DoubleType) + val nullLit = Literal.create(null, NullType) + val intNullLit = Literal.create(null, IntegerType) + val intLit = Literal.create(1, IntegerType) + checkEvaluation(checkDataTypeAndCast(Cot(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(intNullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(intLit)), 1 / math.tan(1), EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(-intLit)), 1 / math.tan(-1), EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(0)), 1 / math.tan(0), EmptyRow) } test("atan") { @@ -250,6 +259,11 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) } + def checkDataTypeAndCast(expression: UnaryMathExpression): Expression = { + val expNew = implicitCast(expression.child, expression.inputTypes(0)).getOrElse(expression) + expression.withNewChildren(Seq(expNew)) + } + test("ceil") { testUnary(Ceil, (d: Double) => math.ceil(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) @@ -262,12 +276,22 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doublePi: Double = 3.1415 val floatPi: Float = 3.1415f val longLit: Long = 12345678901234567L - checkEvaluation(Ceil(doublePi), 4L, EmptyRow) - checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow) - checkEvaluation(Ceil(longLit), longLit, EmptyRow) - checkEvaluation(Ceil(-doublePi), -3L, EmptyRow) - checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow) - checkEvaluation(Ceil(-longLit), -longLit, EmptyRow) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + checkEvaluation(checkDataTypeAndCast(Ceil(doublePi)), 4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(floatPi)), 4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(longLit)), longLit, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-doublePi)), -3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-floatPi)), -3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-longLit)), -longLit, EmptyRow) + + checkEvaluation(checkDataTypeAndCast(Ceil(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(floatNullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(0)), 0L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(1)), 1L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(1234567890123456L)), 1234567890123456L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(0.01)), 1L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-0.10)), 0L, EmptyRow) } test("floor") { @@ -282,12 +306,22 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doublePi: Double = 3.1415 val floatPi: Float = 3.1415f val longLit: Long = 12345678901234567L - checkEvaluation(Floor(doublePi), 3L, EmptyRow) - checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow) - checkEvaluation(Floor(longLit), longLit, EmptyRow) - checkEvaluation(Floor(-doublePi), -4L, EmptyRow) - checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow) - checkEvaluation(Floor(-longLit), -longLit, EmptyRow) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + checkEvaluation(checkDataTypeAndCast(Floor(doublePi)), 3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(floatPi)), 3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(longLit)), longLit, EmptyRow) + 
checkEvaluation(checkDataTypeAndCast(Floor(-doublePi)), -4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-floatPi)), -4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-longLit)), -longLit, EmptyRow) + + checkEvaluation(checkDataTypeAndCast(Floor(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(floatNullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(0)), 0L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(1)), 1L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(1234567890123456L)), 1234567890123456L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(0.01)), 0L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-0.10)), -1L, EmptyRow) } test("factorial") { @@ -541,10 +575,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val intPi: Int = 314159265 val longPi: Long = 31415926535897932L val bdPi: BigDecimal = BigDecimal(31415927L, 7) + val floatPi: Float = 3.1415f val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, 3.1416, 3.14159, 3.141593) + val floatResults: Seq[Float] = Seq(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 3.1f, 3.14f, + 3.141f, 3.1415f, 3.1415f, 3.1415f) + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ Seq.fill[Short](7)(31415) @@ -563,10 +601,12 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + checkEvaluation(Round(floatPi, scale), floatResults(i), EmptyRow) checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow) checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow) + checkEvaluation(BRound(floatPi, scale), floatResults(i), EmptyRow) } val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 5064a1f63f83d..394c0a091e390 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -97,14 +97,30 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doubleLit = Literal.create(2.2, DoubleType) val stringLit = Literal.create("c", StringType) val nullLit = Literal.create(null, NullType) - + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.01f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal.create(10.2, DecimalType(20, 2)) + + assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, floatLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatLit, decimalLit)).dataType == DoubleType) + + assert(analyze(new Nvl(timestampLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(intLit, doubleLit)).dataType == 
DoubleType) assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType) assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType) assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType) + + assert(analyze(new Nvl(floatLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatNullLit, intLit)).dataType == FloatType) } test("AtLeastNNonNulls") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index b69b74b4240bd..58ea5b9cb52d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -33,10 +33,10 @@ class GeneratedProjectionSuite extends SparkFunSuite { test("generated projections on wider table") { val N = 1000 - val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) + val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) val wideRow2 = new GenericInternalRow( - (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) val schema2 = StructType((1 to N).map(i => StructField("", StringType))) val joined = new JoinedRow(wideRow1, wideRow2) val joinedSchema = StructType(schema1 ++ schema2) @@ -48,12 +48,12 @@ class GeneratedProjectionSuite extends SparkFunSuite { val unsafeProj = UnsafeProjection.create(nestedSchema) val unsafe: UnsafeRow = unsafeProj(nested) (0 until N).foreach { i => - val s = UTF8String.fromString((i + 1).toString) - assert(i + 1 === unsafe.getInt(i + 2)) + val s = UTF8String.fromString(i.toString) + assert(i === unsafe.getInt(i + 2)) assert(s === unsafe.getUTF8String(i + 2 + N)) - assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) + assert(i === unsafe.getStruct(0, N * 2).getInt(i)) assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) - assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) + assert(i === unsafe.getStruct(1, N * 2).getInt(i)) assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) } @@ -62,13 +62,63 @@ class GeneratedProjectionSuite extends SparkFunSuite { val result = safeProj(unsafe) // Can't compare GenericInternalRow with JoinedRow directly (0 until N).foreach { i => - val r = i + 1 - val s = UTF8String.fromString((i + 1).toString) - assert(r === result.getInt(i + 2)) + val s = UTF8String.fromString(i.toString) + assert(i === result.getInt(i + 2)) assert(s === result.getUTF8String(i + 2 + N)) - assert(r === result.getStruct(0, N * 2).getInt(i)) + assert(i === result.getStruct(0, N * 2).getInt(i)) assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) - assert(r === result.getStruct(1, N * 2).getInt(i)) + assert(i === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val 
mutableProj = GenerateMutableProjection.generate(exprs) + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } + + test("SPARK-18016: generated projections on wider table requiring class-splitting") { + val N = 4000 + val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === result.getInt(i + 2)) + assert(s === result.getUTF8String(i + 2 + N)) + assert(i === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === result.getStruct(1, N * 2).getInt(i)) assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index b29e1cbd14943..2a04bd588dc1d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -37,7 +37,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper ConstantFolding, BooleanSimplification, SimplifyBinaryComparison, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val nullableRelation = LocalRelation('a.int.withNullability(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index c275f997ba6e9..1df0a89cf0bf1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -38,7 +38,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { NullPropagation(conf), ConstantFolding, BooleanSimplification, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) 
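[Note on the test changes that follow: InferFiltersFromConstraintsSuite and OuterJoinEliminationSuite replace their dedicated OptimizeWithConstraintPropagationDisabled rule executors with an inline try/finally around SQLConf.get.setConf / unsetConf. A hedged sketch of that toggle-and-restore shape, reusing the hypothetical Settings/ActiveSettings types from the earlier sketch rather than Spark's actual API:]

    // Run `body` with the flag forced to `value`, restoring the previous value afterwards,
    // even if `body` throws.
    def withConstraintPropagation[T](value: Boolean)(body: => T): T = {
      val settings = ActiveSettings.get
      val saved = settings.constraintPropagationEnabled
      settings.constraintPropagationEnabled = value
      try body
      finally settings.constraintPropagationEnabled = saved
    }

    // Usage, analogous to the tests below (optimize/plan are placeholders): run the
    // optimizer with propagation disabled and check that no extra filters were inferred.
    // withConstraintPropagation(value = false) { assert(optimize(plan) == plan) }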
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 9a4bcdb011435..cdc9f25cf8777 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED +import org.apache.spark.sql.internal.SQLConf class InferFiltersFromConstraintsSuite extends PlanTest { @@ -32,20 +32,11 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints(conf), + InferFiltersFromConstraints, CombineFilters, BooleanSimplification) :: Nil } - object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { - val batches = - Batch("InferAndPushDownFilters", FixedPoint(100), - PushPredicateThroughJoin, - PushDownPredicate, - InferFiltersFromConstraints(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), - CombineFilters) :: Nil - } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) test("filter: filter out constraints in condition") { @@ -215,8 +206,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("No inferred filter when constraint propagation is disabled") { - val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze - val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery) - comparePlans(optimized, originalQuery) + try { + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } finally { + SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index fdde89d079bc0..50398788c605c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.optimizer -/* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{BooleanType, StringType} class LikeSimplificationSuite extends PlanTest { @@ -100,4 +100,10 @@ class LikeSimplificationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("null pattern") { + val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, 
testRelation.where(Literal(null, BooleanType)).analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index b7136703b7541..a37bc4bca2422 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED +import org.apache.spark.sql.internal.SQLConf class OuterJoinEliminationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -32,16 +32,7 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - EliminateOuterJoin(conf), - PushPredicateThroughJoin) :: Nil - } - - object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubqueryAliases) :: - Batch("Outer Join Elimination", Once, - EliminateOuterJoin(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), + EliminateOuterJoin, PushPredicateThroughJoin) :: Nil } @@ -243,19 +234,25 @@ class OuterJoinEliminationSuite extends PlanTest { } test("no outer join elimination if constraint propagation is disabled") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + try { + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) - // The predicate "x.b + y.d >= 3" will be inferred constraints like: - // "x.b != null" and "y.d != null", if constraint propagation is enabled. - // When we disable it, the predicate can't be evaluated on left or right plan and used to - // filter out nulls. So the Outer Join will not be eliminated. - val originalQuery = + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + // The predicate "x.b + y.d >= 3" will be inferred constraints like: + // "x.b != null" and "y.d != null", if constraint propagation is enabled. + // When we disable it, the predicate can't be evaluated on left or right plan and used to + // filter out nulls. So the Outer Join will not be eliminated. 
+ val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) .where("x.b".attr + "y.d".attr >= 3) - val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) - comparePlans(optimized, originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } finally { + SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 38dff4733f714..2285be16938d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -33,7 +33,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters(conf), + PruneFilters, PropagateEmptyRelation) :: Nil } @@ -45,7 +45,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 741dd0cf428d0..706634cdd29b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class PruneFiltersSuite extends PlanTest { @@ -34,18 +35,7 @@ class PruneFiltersSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown and Pruning", Once, CombineFilters, - PruneFilters(conf), - PushDownPredicate, - PushPredicateThroughJoin) :: Nil - } - - object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubqueryAliases) :: - Batch("Filter Pushdown and Pruning", Once, - CombineFilters, - PruneFilters(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), + PruneFilters, PushDownPredicate, PushPredicateThroughJoin) :: Nil } @@ -159,15 +149,19 @@ class PruneFiltersSuite extends PlanTest { ("tr1.a".attr > 10 || "tr1.c".attr < 10) && 'd.attr < 100) - val optimized = - OptimizeWithConstraintPropagationDisabled.execute(queryWithUselessFilter.analyze) - // When constraint propagation is disabled, the useless filter won't be pruned. - // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant - // and duplicate filters. 
- val correctAnswer = tr1 - .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10) - .join(tr2.where('d.attr < 100).where('d.attr < 100), + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + try { + val optimized = Optimize.execute(queryWithUselessFilter.analyze) + // When constraint propagation is disabled, the useless filter won't be pruned. + // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant + // and duplicate filters. + val correctAnswer = tr1 + .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10) + .join(tr2.where('d.attr < 100).where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, correctAnswer) + } finally { + SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 756e0f35b2178..21b7f49e14bd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -34,7 +34,7 @@ class SetOperationSuite extends PlanTest { CombineUnions, PushProjectionThroughUnion, PushDownPredicate, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index f33abc5b2e049..76be6ee3f50bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -51,7 +51,7 @@ class TableIdentifierParserSuite extends SparkFunSuite { "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at") + "int", "smallint", "timestamp", "at", "position") val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 4061394b862a6..a3948d90b0e4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType} class ConstraintPropagationSuite extends SparkFunSuite { @@ -399,20 
+400,26 @@ class ConstraintPropagationSuite extends SparkFunSuite { } test("enable/disable constraint propagation") { - val tr = LocalRelation('a.int, 'b.string, 'c.int) - val filterRelation = tr.where('a.attr > 10) + try { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + val filterRelation = tr.where('a.attr > 10) - verifyConstraints( - filterRelation.analyze.getConstraints(constraintPropagationEnabled = true), - filterRelation.analyze.constraints) + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true) + assert(filterRelation.analyze.constraints.nonEmpty) - assert(filterRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + assert(filterRelation.analyze.constraints.isEmpty) - val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) - .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) - verifyConstraints(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = true), - aliasedRelation.analyze.constraints) - assert(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true) + assert(aliasedRelation.analyze.constraints.nonEmpty) + + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + assert(aliasedRelation.analyze.constraints.isEmpty) + } finally { + SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 712841835acd5..819078218c546 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } -case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { +case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable { override def children: Seq[Expression] = map.values.toSeq override def nullable: Boolean = true override def dataType: NullType = NullType override lazy val resolved = true } +case class SeqTupleExpression(sons: Seq[(Expression, Expression)], + nonSons: Seq[(Expression, Expression)]) extends Unevaluable { + override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2)) + override def nullable: Boolean = true + override def dataType: NullType = NullType + override lazy val resolved = true +} + case class JsonTestTreeNode(arg: Any) extends LeafNode { override def output: Seq[Attribute] = Seq.empty[Attribute] } @@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite { assert(actual === Dummy(None)) } + test("mapChildren should only works on children") { + val children = Seq((Literal(1), Literal(2))) + val nonChildren = Seq((Literal(3), Literal(4))) + val before = SeqTupleExpression(children, nonChildren) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren) + + val actual = before mapChildren toZero + assert(actual === expect) + } + test("preserves 
origin") { CurrentOrigin.setPosition(1, 1) val add = Add(Literal(1), Literal(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 9799817494f15..c8cf16d937352 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -34,6 +34,22 @@ class DateTimeUtilsSuite extends SparkFunSuite { ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt } + test("nanoseconds truncation") { + def checkStringToTimestamp(originalTime: String, expectedParsedTime: String) { + val parsedTimestampOp = DateTimeUtils.stringToTimestamp(UTF8String.fromString(originalTime)) + assert(parsedTimestampOp.isDefined, "timestamp with nanoseconds was not parsed correctly") + assert(DateTimeUtils.timestampToString(parsedTimestampOp.get) === expectedParsedTime) + } + + checkStringToTimestamp("2015-01-02 00:00:00.123456789", "2015-01-02 00:00:00.123456") + checkStringToTimestamp("2015-01-02 00:00:00.100000009", "2015-01-02 00:00:00.1") + checkStringToTimestamp("2015-01-02 00:00:00.000050000", "2015-01-02 00:00:00.00005") + checkStringToTimestamp("2015-01-02 00:00:00.12005", "2015-01-02 00:00:00.12005") + checkStringToTimestamp("2015-01-02 00:00:00.100", "2015-01-02 00:00:00.1") + checkStringToTimestamp("2015-01-02 00:00:00.000456789", "2015-01-02 00:00:00.000456") + checkStringToTimestamp("1950-01-02 00:00:00.000456789", "1950-01-02 00:00:00.000456") + } + test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 93c231e30b49b..144f3d688d402 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -32,6 +32,16 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("0.09")), "0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("0.9")), "0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("0.90")), "0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("0.0")), "0.0", 2, 1) + checkDecimal(Decimal(BigDecimal("0")), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("1.0")), "1.0", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.09")), "-0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("-0.9")), "-0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.90")), "-0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("-1.0")), "-1.0", 2, 1) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index fe4be963e8184..7327c9b0c9c50 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -183,6 +183,13 @@ + + org.scalatest + scalatest-maven-plugin + + -Xmx4g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + org.codehaus.mojo build-helper-maven-plugin diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 
cd521c52d1b21..8fea46a58e857 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -63,8 +63,6 @@ public final class UnsafeFixedWidthAggregationMap { */ private final UnsafeRow currentAggregationBuffer; - private final boolean enablePerfMetrics; - /** * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given * schema, false otherwise. @@ -87,7 +85,6 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. - * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( InternalRow emptyAggregationBuffer, @@ -95,15 +92,13 @@ public UnsafeFixedWidthAggregationMap( StructType groupingKeySchema, TaskMemoryManager taskMemoryManager, int initialCapacity, - long pageSizeBytes, - boolean enablePerfMetrics) { + long pageSizeBytes) { this.aggregationBufferSchema = aggregationBufferSchema; this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; this.map = - new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); - this.enablePerfMetrics = enablePerfMetrics; + new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); @@ -223,15 +218,11 @@ public void free() { map.free(); } - @SuppressWarnings("UseOfSystemOutOrSystemErr") - public void printPerfMetrics() { - if (!enablePerfMetrics) { - throw new IllegalStateException("Perf metrics not enabled"); - } - System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup()); - System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); - System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); - System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + /** + * Gets the average hash map probe per looking up for the underlying `BytesToBytesMap`. 
+ */ + public double getAverageProbesPerLookup() { + return map.getAverageProbesPerLookup(); } /** diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 27d32b5dca431..0c5f3f22e31e8 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,3 +5,4 @@ org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.TextSocketSourceProvider +org.apache.spark.sql.execution.streaming.RateSourceProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0f96e82cedf4e..a1d8b7f4af1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -295,7 +295,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * Loads JSON files and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `wholeFile` option to true. + * default. For JSON (one record per file), set the `multiLine` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -335,7 +335,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
<li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
 * indicates a timestamp format. Custom date formats follow the formats at
 * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
- * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines,
+ * <li>`multiLine` (default `false`): parse one record, which may span multiple lines,
 * per file</li>
 * </ul>
 *
@@ -537,7 +537,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 * <li>`columnNameOfCorruptRecord` (default is the value specified in
 * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
 * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
- * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li>
+ * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
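For context, a minimal usage sketch of the renamed option (assuming an active SparkSession named `spark` and illustrative paths); only the option name changes, the parsing behaviour is the same as the old `wholeFile`:

// JSON: one record that may span multiple lines per file
val jsonDF = spark.read.option("multiLine", true).json("path/to/file.json")

// CSV: same option name after the rename
val csvDF = spark.read.option("multiLine", true).csv("path/to/file.csv")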
  • * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index c856d3099f6ee..531c613afb0dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -551,7 +551,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { ) } - singleCol.queryExecution.toRdd.aggregate(zero)( + singleCol.queryExecution.toRdd.treeAggregate(zero)( (filter: BloomFilter, row: InternalRow) => { updater(filter, row) filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 17671ea8685b9..86574e2f71d92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.Map import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -166,6 +167,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.2.0 */ implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Maps + /** @since 2.3.0 */ + implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Arrays /** @since 1.6.1 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d2bf350711936..2c38f7d7c88da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -87,6 +87,11 @@ class SparkSession private( sparkContext.assertNotStopped() + // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's. + SQLConf.setSQLConfGetter(() => { + SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf) + }) + /** * The version of Spark on which this application is running. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index e86116680a57a..74a47da2deef2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -93,7 +93,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val nextBatch = ctx.freshName("nextBatch") - ctx.addNewFunction(nextBatch, + val nextBatchFuncName = ctx.addNewFunction(nextBatch, s""" |private void $nextBatch() throws java.io.IOException { | long getBatchStart = System.nanoTime(); @@ -121,7 +121,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } s""" |if ($batch == null) { - | $nextBatch(); + | $nextBatchFuncName(); |} |while ($batch != null) { | int $numRows = $batch.numRows(); @@ -133,7 +133,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | } | $idx = $numRows; | $batch = null; - | $nextBatch(); + | $nextBatchFuncName(); |} |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); |$scanTimeTotalNs = 0; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index f98ae82574d20..ff71fd4dc7bb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -141,7 +141,7 @@ case class SortExec( ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") val addToSorter = ctx.freshName("addToSorter") - ctx.addNewFunction(addToSorter, + val addToSorterFuncName = ctx.addNewFunction(addToSorter, s""" | private void $addToSorter() throws java.io.IOException { | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} @@ -160,7 +160,7 @@ case class SortExec( s""" | if ($needToSort) { | long $spillSizeBefore = $metrics.memoryBytesSpilled(); - | $addToSorter(); + | $addToSorterFuncName(); | $sortedIterator = $sorterVariable.sort(); | $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000); | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index ac30b11557adb..0bd28e36135c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -357,6 +357,9 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co protected void processNext() throws java.io.IOException { ${code.trim} } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """.trim diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 68c8e6ce62cbb..5027a615ced7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -59,7 +59,8 @@ case class HashAggregateExec( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), "spillSize" -> 
SQLMetrics.createSizeMetric(sparkContext, "spill size"), - "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time")) + "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"), + "avgHashmapProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hashmap probe")) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -93,6 +94,7 @@ case class HashAggregateExec( val numOutputRows = longMetric("numOutputRows") val peakMemory = longMetric("peakMemory") val spillSize = longMetric("spillSize") + val avgHashmapProbe = longMetric("avgHashmapProbe") child.execute().mapPartitions { iter => @@ -116,7 +118,8 @@ case class HashAggregateExec( testFallbackStartsAt, numOutputRows, peakMemory, - spillSize) + spillSize, + avgHashmapProbe) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) @@ -157,7 +160,7 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer + // The variables used as aggregation buffer. Only used for aggregation without keys. private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { @@ -209,7 +212,7 @@ case class HashAggregateExec( } val doAgg = ctx.freshName("doAggregateWithoutKey") - ctx.addNewFunction(doAgg, + val doAggFuncName = ctx.addNewFunction(doAgg, s""" | private void $doAgg() throws java.io.IOException { | // initialize aggregation buffer @@ -226,7 +229,7 @@ case class HashAggregateExec( | while (!$initAgg) { | $initAgg = true; | long $beforeAgg = System.nanoTime(); - | $doAgg(); + | $doAggFuncName(); | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); | | // output the result @@ -312,8 +315,7 @@ case class HashAggregateExec( groupingKeySchema, TaskContext.get().taskMemoryManager(), 1024 * 16, // initial capacity - TaskContext.get().taskMemoryManager().pageSizeBytes, - false // disable tracking of performance metrics + TaskContext.get().taskMemoryManager().pageSizeBytes ) } @@ -341,7 +343,8 @@ case class HashAggregateExec( hashMap: UnsafeFixedWidthAggregationMap, sorter: UnsafeKVExternalSorter, peakMemory: SQLMetric, - spillSize: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { + spillSize: SQLMetric, + avgHashmapProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { // update peak execution memory val mapMemory = hashMap.getPeakMemoryUsedBytes @@ -351,6 +354,10 @@ case class HashAggregateExec( peakMemory.add(maxMemory) metrics.incPeakExecutionMemory(maxMemory) + // Update average hashmap probe + val avgProbes = hashMap.getAverageProbesPerLookup() + avgHashmapProbe.add(avgProbes.ceil.toLong) + if (sorter == null) { // not spilled return hashMap.iterator() @@ -577,6 +584,7 @@ case class HashAggregateExec( val doAgg = ctx.freshName("doAggregateWithKeys") val peakMemory = metricTerm(ctx, "peakMemory") val spillSize = metricTerm(ctx, "spillSize") + val avgHashmapProbe = metricTerm(ctx, "avgHashmapProbe") def generateGenerateCode(): String = { if (isFastHashMapEnabled) { @@ -592,7 +600,7 @@ case class HashAggregateExec( } else "" } - ctx.addNewFunction(doAgg, + val doAggFuncName = ctx.addNewFunction(doAgg, s""" ${generateGenerateCode} private void $doAgg() throws java.io.IOException { @@ -602,7 +610,8 @@ case class HashAggregateExec( ${if (isFastHashMapEnabled) { s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""} - $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize); + $iterTerm 
= $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize, + $avgHashmapProbe); } """) @@ -672,7 +681,7 @@ case class HashAggregateExec( if (!$initAgg) { $initAgg = true; long $beforeAgg = System.nanoTime(); - $doAgg(); + $doAggFuncName(); $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); } @@ -792,6 +801,8 @@ case class HashAggregateExec( | $unsafeRowBuffer = | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); | } + | // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based + | // aggregation after processing all input rows. | if ($unsafeRowBuffer == null) { | if ($sorterTerm == null) { | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); @@ -800,7 +811,7 @@ case class HashAggregateExec( | } | $resetCounter | // the hash map had be spilled, it should have enough memory now, - | // try to allocate buffer again. + | // try to allocate buffer again. | $unsafeRowBuffer = | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); | if ($unsafeRowBuffer == null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 2988161ee5e7b..8efa95d48aea0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -88,7 +88,8 @@ class TungstenAggregationIterator( testFallbackStartsAt: Option[(Int, Int)], numOutputRows: SQLMetric, peakMemory: SQLMetric, - spillSize: SQLMetric) + spillSize: SQLMetric, + avgHashmapProbe: SQLMetric) extends AggregationIterator( groupingExpressions, originalInputAttributes, @@ -162,8 +163,7 @@ class TungstenAggregationIterator( StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), TaskContext.get().taskMemoryManager(), 1024 * 16, // initial capacity - TaskContext.get().taskMemoryManager().pageSizeBytes, - false // disable tracking of performance metrics + TaskContext.get().taskMemoryManager().pageSizeBytes ) // The function used to read and process input rows. When processing input rows, @@ -420,6 +420,10 @@ class TungstenAggregationIterator( peakMemory += maxMemory spillSize += metrics.memoryBytesSpilled - spillSizeBefore metrics.incPeakExecutionMemory(maxMemory) + + // Update average hashmap probe if this is the last record. 
+ val averageProbes = hashMap.getAverageProbesPerLookup() + avgHashmapProbe.add(averageProbes.ceil.toLong) } numOutputRows += 1 res diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index f69a688555bbf..f3ca8397047fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext} -import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} +import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} @@ -281,10 +281,8 @@ case class SampleExec( val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") ctx.copyResult = true - ctx.addMutableState(s"$samplerClass", sampler, - s"$initSampler();") - ctx.addNewFunction(initSampler, + val initSamplerFuncName = ctx.addNewFunction(initSampler, s""" | private void $initSampler() { | $sampler = new $samplerClass($upperBound - $lowerBound, false); @@ -299,6 +297,9 @@ case class SampleExec( | } """.stripMargin.trim) + ctx.addMutableState(s"$samplerClass", sampler, + s"$initSamplerFuncName();") + val samplingCount = ctx.freshName("samplingCount") s""" | int $samplingCount = $sampler.sample(); @@ -347,8 +348,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } override def inputRDDs(): Seq[RDD[InternalRow]] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) :: Nil + val rdd = if (start == end || (start < end ^ 0 < step)) { + new EmptyRDD[InternalRow](sqlContext.sparkContext) + } else { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + rdd :: Nil } protected override def doProduce(ctx: CodegenContext): String = { @@ -390,7 +395,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // The default size of a batch, which must be positive integer val batchSize = 1000 - ctx.addNewFunction("initRange", + val initRangeFuncName = ctx.addNewFunction("initRange", s""" | private void initRange(int idx) { | $BigInt index = $BigInt.valueOf(idx); @@ -447,7 +452,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | // initialize Range | if (!$initTerm) { | $initTerm = true; - | initRange(partitionIndex); + | $initRangeFuncName(partitionIndex); | } | | while (true) { @@ -595,9 +600,6 @@ case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends Spa */ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { - // Ignore this wrapper for canonicalizing. 
- override lazy val canonicalized: SparkPlan = child.canonicalized - override lazy val metrics = Map( "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"), "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 14024d6c10558..d3fa0dcd2d7c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -128,9 +128,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } else { val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) - var groupedAccessorsLength = 0 - groupedAccessorsItr.zipWithIndex.foreach { case (body, i) => - groupedAccessorsLength += 1 + val accessorNames = groupedAccessorsItr.zipWithIndex.map { case (body, i) => val funcName = s"accessors$i" val funcCode = s""" |private void $funcName() { @@ -139,7 +137,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - groupedExtractorsItr.zipWithIndex.foreach { case (body, i) => + val extractorNames = groupedExtractorsItr.zipWithIndex.map { case (body, i) => val funcName = s"extractors$i" val funcCode = s""" |private void $funcName() { @@ -148,8 +146,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"), - (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n")) + (accessorNames.map { accessorName => s"$accessorName();" }.mkString("\n"), + extractorNames.map { extractorName => s"$extractorName();"}.mkString("\n")) } val codeBody = s""" @@ -224,6 +222,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera unsafeRow.setTotalSize(bufferHolder.totalSize()); return unsafeRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 2de14c90ec757..2f273b63e8348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -54,7 +54,7 @@ case class AnalyzeColumnCommand( // Newly computed column stats should override the existing ones. colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) - sessionState.catalog.alterTable(tableMeta.copy(stats = Some(statistics))) + sessionState.catalog.alterTableStats(tableIdentWithDB, statistics) // Refresh the cached data source table in the catalog. 
sessionState.catalog.refreshTable(tableIdentWithDB) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 3183c7911b1fb..3c59b982c2dca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -69,7 +69,7 @@ case class AnalyzeTableCommand( // Update the metastore if the above statistics of the table are different from those // recorded in the metastore. if (newStats.isDefined) { - sessionState.catalog.alterTable(tableMeta.copy(stats = newStats)) + sessionState.catalog.alterTableStats(tableIdentWithDB, newStats.get) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 793fb9b795596..f924b3d914635 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -21,7 +21,6 @@ import java.util.Locale import scala.collection.{GenMap, GenSeq} import scala.collection.parallel.ForkJoinTaskSupport -import scala.concurrent.forkjoin.ForkJoinPool import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.types._ -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -236,7 +235,7 @@ case class AlterTableSetPropertiesCommand( // direct property. 
val newTable = table.copy( properties = table.properties ++ properties, - comment = properties.get("comment")) + comment = properties.get("comment").orElse(table.comment)) catalog.alterTable(newTable) Seq.empty[Row] } @@ -588,8 +587,15 @@ case class AlterTableRecoverPartitionsCommand( val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) - val partitionSpecsAndLocs = scanPartitions(spark, fs, pathFilter, root, Map(), - table.partitionColumnNames, threshold, spark.sessionState.conf.resolver) + + val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) + val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] = + try { + scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, + spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq + } finally { + evalPool.shutdown() + } val total = partitionSpecsAndLocs.length logInfo(s"Found $total partitions in $root") @@ -610,8 +616,6 @@ case class AlterTableRecoverPartitionsCommand( Seq.empty[Row] } - @transient private lazy val evalTaskSupport = new ForkJoinTaskSupport(new ForkJoinPool(8)) - private def scanPartitions( spark: SparkSession, fs: FileSystem, @@ -620,7 +624,8 @@ case class AlterTableRecoverPartitionsCommand( spec: TablePartitionSpec, partitionNames: Seq[String], threshold: Int, - resolver: Resolver): GenSeq[(TablePartitionSpec, Path)] = { + resolver: Resolver, + evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = { if (partitionNames.isEmpty) { return Seq(spec -> path) } @@ -644,7 +649,7 @@ case class AlterTableRecoverPartitionsCommand( val value = ExternalCatalogUtils.unescapePathName(ps(1)) if (resolver(columnName, partitionNames.head)) { scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), - partitionNames.drop(1), threshold, resolver) + partitionNames.drop(1), threshold, resolver, evalTaskSupport) } else { logWarning( s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 905b8683e10bd..f5df1848a38c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.catalog.CatalogStatistics import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -59,8 +60,11 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(sparkSession) - val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation) - + // Change table stats based on the sizeInBytes of pruned files + val withStats = logicalRelation.catalogTable.map(_.copy( + stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes))))) + val prunedLogicalRelation = logicalRelation.copy( + 
relation = prunedFsRelation, catalogTable = withStats) // Keep partition-pruning predicates so that they are visible in physical planning val filterExpression = filters.reduceLeft(And) val filter = Filter(filterExpression, prunedLogicalRelation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 76f121c0c955f..eadc6c94f4b3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -111,8 +111,8 @@ abstract class CSVDataSource extends Serializable { object CSVDataSource { def apply(options: CSVOptions): CSVDataSource = { - if (options.wholeFile) { - WholeFileCSVDataSource + if (options.multiLine) { + MultiLineCSVDataSource } else { TextInputCSVDataSource } @@ -197,7 +197,7 @@ object TextInputCSVDataSource extends CSVDataSource { } } -object WholeFileCSVDataSource extends CSVDataSource { +object MultiLineCSVDataSource extends CSVDataSource { override val isSplitable: Boolean = false override def readFile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 78c16b75ee684..a13a5a34b4a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -128,7 +128,7 @@ class CSVOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) - val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 4f2963da9ace9..5a92a71d19e78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -86,8 +86,8 @@ abstract class JsonDataSource extends Serializable { object JsonDataSource { def apply(options: JSONOptions): JsonDataSource = { - if (options.wholeFile) { - WholeFileJsonDataSource + if (options.multiLine) { + MultiLineJsonDataSource } else { TextInputJsonDataSource } @@ -147,7 +147,7 @@ object TextInputJsonDataSource extends JsonDataSource { } } -object WholeFileJsonDataSource extends JsonDataSource { +object MultiLineJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 26fb6103953fc..8445c26eeee58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -478,7 +478,7 @@ case class SortMergeJoinExec( | } | return false; // unreachable |} - """.stripMargin) + """.stripMargin, inlineToOuterClass = true) (leftRow, matches) } 
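The codegen changes in this patch share one pattern: callers capture the name returned by `ctx.addNewFunction(...)` instead of hard-coding the fresh name, since once generated code is split into nested classes the function may only be reachable through a qualified name; functions that must remain callable from the outer class (such as `stopEarly()` in `limit.scala` below) are pinned there with `inlineToOuterClass = true`. A rough sketch of the calling convention, not taken verbatim from the patch:

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext
// addNewFunction may place the body in a nested class and return a qualified name,
// so the returned value -- not the literal "doSomething" -- must be used at the call site.
val doSomething = ctx.addNewFunction("doSomething",
  "private void doSomething() throws java.io.IOException { /* generated body */ }")
val callSite = s"$doSomething();"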
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 757fe2185d302..73a0f8735ed45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -75,7 +75,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { protected boolean stopEarly() { return $stopEarly; } - """) + """, inlineToOuterClass = true) val countTerm = ctx.freshName("count") ctx.addMutableState("int", countTerm, s"$countTerm = 0;") s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index ef982a4ebd10d..49cab04de2bf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -68,11 +68,11 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato } } - object SQLMetrics { private val SUM_METRIC = "sum" private val SIZE_METRIC = "size" private val TIMING_METRIC = "timing" + private val AVERAGE_METRIC = "average" def createMetric(sc: SparkContext, name: String): SQLMetric = { val acc = new SQLMetric(SUM_METRIC) @@ -102,6 +102,22 @@ object SQLMetrics { acc } + /** + * Create a metric to report the average information (including min, med, max) like + * avg hashmap probe. Because `SQLMetric` stores long values, we take the ceil of the average + * values before storing them. This metric is used to record an average value computed in the + * end of a task. It should be set once. The initial values (zeros) of this metrics will be + * excluded after. + */ + def createAverageMetric(sc: SparkContext, name: String): SQLMetric = { + // The final result of this metric in physical operator UI may looks like: + // probe avg (min, med, max): + // (1, 2, 6) + val acc = new SQLMetric(AVERAGE_METRIC) + acc.register(sc, name = Some(s"$name (min, med, max)"), countFailedValues = false) + acc + } + /** * A function that defines how we aggregate the final accumulator results among all tasks, * and represent it in string for a SQL physical operator. @@ -110,6 +126,20 @@ object SQLMetrics { if (metricsType == SUM_METRIC) { val numberFormat = NumberFormat.getIntegerInstance(Locale.US) numberFormat.format(values.sum) + } else if (metricsType == AVERAGE_METRIC) { + val numberFormat = NumberFormat.getIntegerInstance(Locale.US) + + val validValues = values.filter(_ > 0) + val Seq(min, med, max) = { + val metric = if (validValues.isEmpty) { + Seq.fill(3)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(numberFormat.format) + } + s"\n($min, $med, $max)" } else { val strFormat: Long => String = if (metricsType == SIZE_METRIC) { Utils.bytesToString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala new file mode 100644 index 0000000000000..e61a8eb628891 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.TimeUnit + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types._ +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateSourceProvider extends StreamSourceProvider with DataSourceRegister { + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = + (shortName(), RateSourceProvider.SCHEMA) + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val params = CaseInsensitiveMap(parameters) + + val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + + "must be positive") + } + + val rampUpTimeSeconds = + params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + + "must not be negative") + } + + val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( + sqlContext.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("numPartitions")}'. 
The option 'numPartitions' " + + "must be positive") + } + + new RateStreamSource( + sqlContext, + metadataPath, + rowsPerSecond, + rampUpTimeSeconds, + numPartitions, + params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing + ) + } + override def shortName(): String = "rate" +} + +object RateSourceProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 +} + +class RateStreamSource( + sqlContext: SQLContext, + metadataPath: String, + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int, + useManualClock: Boolean) extends Source with Logging { + + import RateSourceProvider._ + import RateStreamSource._ + + val clock = if (useManualClock) new ManualClock else new SystemClock + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private val startTimeMs = { + val metadataLog = + new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ + @volatile private var lastTimeMs = startTimeMs + + override def schema: StructType = RateSourceProvider.SCHEMA + + override def getOffset: Option[Offset] = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. 
Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return sqlContext.internalCreateDataFrame(sqlContext.sparkContext.emptyRDD, schema) + } + + val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + + val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => + val relative = math.round((v - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) + } + sqlContext.internalCreateDataFrame(rdd, schema) + } + + override def stop(): Unit = {} + + override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" +} + +object RateStreamSource { + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 2abeadfe45362..d11045fb6ac8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -156,7 +156,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { - if (!conf.subqueryReuseEnabled) { + if (!conf.exchangeReuseEnabled) { return plan } // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8d0a8c2178803..8d2e1f32da059 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1210,7 +1210,7 @@ object functions { /** * Creates a new struct column. * If the input column is a column in a `DataFrame`, or a derived column expression - * that is named (i.e. aliased), its name would be remained as the StructField's name, + * that is named (i.e. aliased), its name would be retained as the StructField's name, * otherwise, the newly generated StructField's name would be auto generated as * `col` with a suffix `index + 1`, i.e. col1, col2, col3, ... * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 766776230257d..7e8e6394b4862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -163,7 +163,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads a JSON file stream and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `wholeFile` option to true. + * default. For JSON (one record per file), set the `multiLine` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -205,7 +205,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
<li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.</li> - * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines, + * <li>`multiLine` (default `false`): parse one record, which may span multiple lines, * per file</li> * </ul> * * @@ -276,7 +276,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * <li>`columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li> - * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li> + * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 3ba37addfc8b4..4ca3b6406a328 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1399,4 +1399,65 @@ public void testSerializeNull() { ds1.map((MapFunction) b -> b, encoder); Assert.assertEquals(beans, ds2.collectAsList()); } + + @Test + public void testSpecificLists() { + SpecificListsBean bean = new SpecificListsBean(); + ArrayList arrayList = new ArrayList<>(); + arrayList.add(1); + bean.setArrayList(arrayList); + LinkedList linkedList = new LinkedList<>(); + linkedList.add(1); + bean.setLinkedList(linkedList); + bean.setList(Collections.singletonList(1)); + List beans = Collections.singletonList(bean); + Dataset dataset = + spark.createDataset(beans, Encoders.bean(SpecificListsBean.class)); + Assert.assertEquals(beans, dataset.collectAsList()); + } + + public static class SpecificListsBean implements Serializable { + private ArrayList arrayList; + private LinkedList linkedList; + private List list; + + public ArrayList getArrayList() { + return arrayList; + } + + public void setArrayList(ArrayList arrayList) { + this.arrayList = arrayList; + } + + public LinkedList getLinkedList() { + return linkedList; + } + + public void setLinkedList(LinkedList linkedList) { + this.linkedList = linkedList; + } + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SpecificListsBean that = (SpecificListsBean) o; + return Objects.equal(arrayList, that.arrayList) && + Objects.equal(linkedList, that.linkedList) && + Objects.equal(list, that.list); + } + + @Override + public int hashCode() { + return Objects.hashCode(arrayList, linkedList, list); + } + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index 6de4cf0d5afa1..91b966829f8fb 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -1,4 +1,5 @@ CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + OPTIONS (a '1', b '2') PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS COMMENT 'table_comment'; @@ -13,6 +14,8 @@ CREATE TEMPORARY VIEW temp_Data_Source_View CREATE VIEW v AS SELECT * FROM t; +ALTER TABLE t SET TBLPROPERTIES (e = '3'); + ALTER TABLE t ADD PARTITION (c='Us', d=1); DESCRIBE t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 7e3b86b76a34a..3934620577e99 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -65,8 +65,18 @@ select ceiling(0); select ceiling(1); select ceil(1234567890123456); select ceiling(1234567890123456); +select ceil(0.01); +select ceiling(-0.10); -- floor select floor(0); select floor(1); select floor(1234567890123456); +select floor(0.01); +select floor(-0.10); + +-- comparison operator +select 1 > 0.00001; + +-- mod +select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null); diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index d82df11251c5b..20c0390664037 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -15,3 +15,6 @@ select replace('abc', 'b'); -- uuid select length(uuid()), (uuid() <> uuid()); + +-- position +select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null); diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out index eece00d603db4..4bf4633491bd9 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out @@ -57,7 +57,7 @@ Last Access [not included in comparison] Type MANAGED Provider parquet Comment modified comment -Properties [type=parquet] +Table Properties [type=parquet] Location [not included in comparison]sql/core/spark-warehouse/table_with_comment diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 46d32bbc52247..ab9f2783f06bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,9 +1,10 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 31 +-- Number of queries: 32 -- !query 0 CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + OPTIONS (a '1', b '2') PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS COMMENT 'table_comment' -- !query 0 schema @@ -42,7 +43,7 @@ struct<> -- !query 4 -ALTER TABLE t ADD PARTITION (c='Us', d=1) +ALTER TABLE t SET TBLPROPERTIES (e = '3') -- !query 4 schema struct<> -- !query 4 output @@ -50,10 +51,18 @@ struct<> -- !query 5 -DESCRIBE t +ALTER TABLE t ADD PARTITION (c='Us', d=1) -- !query 5 schema -struct +struct<> -- !query 5 output + + + +-- !query 6 +DESCRIBE t +-- !query 6 schema +struct +-- !query 6 output a string b int c string @@ -64,11 +73,11 @@ c string d string --- !query 6 +-- !query 7 DESC default.t --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output a string b int c string @@ -79,11 +88,11 @@ c string d string --- !query 7 +-- !query 8 DESC TABLE t --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output a string b int c string @@ -94,11 +103,11 @@ c string d string --- !query 8 +-- !query 9 DESC FORMATTED t --- !query 8 schema +-- !query 9 schema struct --- !query 8 output +-- !query 9 output a string b int c string @@ -119,15 +128,17 @@ Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] Comment table_comment +Table Properties [e=3] Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] Partition Provider Catalog --- !query 9 +-- !query 10 DESC EXTENDED t --- !query 9 schema +-- !query 10 schema struct --- !query 9 output +-- !query 10 output a string b int c string @@ -148,15 +159,17 @@ Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] Comment table_comment +Table Properties [e=3] Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] Partition Provider Catalog --- !query 10 +-- !query 11 DESC t PARTITION (c='Us', 
d=1) --- !query 10 schema +-- !query 11 schema struct --- !query 10 output +-- !query 11 output a string b int c string @@ -167,11 +180,11 @@ c string d string --- !query 11 +-- !query 12 DESC EXTENDED t PARTITION (c='Us', d=1) --- !query 11 schema +-- !query 12 schema struct --- !query 11 output +-- !query 12 output a string b int c string @@ -186,19 +199,21 @@ Database default Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 +Storage Properties [a=1, b=2] # Storage Information Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] -Location [not included in comparison]sql/core/spark-warehouse/t +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] --- !query 12 +-- !query 13 DESC FORMATTED t PARTITION (c='Us', d=1) --- !query 12 schema +-- !query 13 schema struct --- !query 12 output +-- !query 13 output a string b int c string @@ -213,39 +228,41 @@ Database default Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 +Storage Properties [a=1, b=2] # Storage Information Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] -Location [not included in comparison]sql/core/spark-warehouse/t +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] --- !query 13 +-- !query 14 DESC t PARTITION (c='Us', d=2) --- !query 13 schema +-- !query 14 schema struct<> --- !query 13 output +-- !query 14 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 14 +-- !query 15 DESC t PARTITION (c='Us') --- !query 14 schema +-- !query 15 schema struct<> --- !query 14 output +-- !query 15 output org.apache.spark.sql.AnalysisException Partition spec is invalid. 
The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 15 +-- !query 16 DESC t PARTITION (c='Us', d) --- !query 15 schema +-- !query 16 schema struct<> --- !query 15 output +-- !query 16 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -255,19 +272,8 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 16 -DESC temp_v --- !query 16 schema -struct --- !query 16 output -a string -b int -c string -d string - - -- !query 17 -DESC TABLE temp_v +DESC temp_v -- !query 17 schema struct -- !query 17 output @@ -278,7 +284,7 @@ d string -- !query 18 -DESC FORMATTED temp_v +DESC TABLE temp_v -- !query 18 schema struct -- !query 18 output @@ -289,7 +295,7 @@ d string -- !query 19 -DESC EXTENDED temp_v +DESC FORMATTED temp_v -- !query 19 schema struct -- !query 19 output @@ -300,10 +306,21 @@ d string -- !query 20 -DESC temp_Data_Source_View +DESC EXTENDED temp_v -- !query 20 schema struct -- !query 20 output +a string +b int +c string +d string + + +-- !query 21 +DESC temp_Data_Source_View +-- !query 21 schema +struct +-- !query 21 output intType int test comment test1 stringType string dateType date @@ -322,42 +339,42 @@ arrayType array structType struct --- !query 21 +-- !query 22 DESC temp_v PARTITION (c='Us', d=1) --- !query 21 schema +-- !query 22 schema struct<> --- !query 21 output +-- !query 22 output org.apache.spark.sql.AnalysisException DESC PARTITION is not allowed on a temporary view: temp_v; --- !query 22 +-- !query 23 DESC v --- !query 22 schema +-- !query 23 schema struct --- !query 22 output +-- !query 23 output a string b int c string d string --- !query 23 +-- !query 24 DESC TABLE v --- !query 23 schema +-- !query 24 schema struct --- !query 23 output +-- !query 24 output a string b int c string d string --- !query 24 +-- !query 25 DESC FORMATTED v --- !query 24 schema +-- !query 25 schema struct --- !query 24 output +-- !query 25 output a string b int c string @@ -372,14 +389,14 @@ Type VIEW View Text SELECT * FROM t View Default Database default View Query Output Columns [a, b, c, d] -Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] +Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] --- !query 25 +-- !query 26 DESC EXTENDED v --- !query 25 schema +-- !query 26 schema struct --- !query 25 output +-- !query 26 output a string b int c string @@ -394,28 +411,20 @@ Type VIEW View Text SELECT * FROM t View Default Database default View Query Output Columns [a, b, c, d] -Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] - - --- !query 26 -DESC v PARTITION (c='Us', d=1) --- !query 26 schema -struct<> --- !query 26 output -org.apache.spark.sql.AnalysisException -DESC PARTITION is not allowed on a view: v; +Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] -- !query 27 -DROP TABLE t +DESC v PARTITION (c='Us', d=1) -- !query 27 schema struct<> -- !query 27 output - +org.apache.spark.sql.AnalysisException +DESC PARTITION is not allowed on a view: v; -- !query 28 -DROP VIEW temp_v +DROP TABLE t -- !query 28 schema struct<> -- 
!query 28 output @@ -423,7 +432,7 @@ struct<> -- !query 29 -DROP VIEW temp_Data_Source_View +DROP VIEW temp_v -- !query 29 schema struct<> -- !query 29 output @@ -431,8 +440,16 @@ struct<> -- !query 30 -DROP VIEW v +DROP VIEW temp_Data_Source_View -- !query 30 schema struct<> -- !query 30 output + + +-- !query 31 +DROP VIEW v +-- !query 31 schema +struct<> +-- !query 31 output + diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 28cfb744193ec..51ccf764d952f 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 45 +-- Number of queries: 51 -- !query 0 @@ -351,24 +351,72 @@ struct -- !query 42 -select floor(0) +select ceil(0.01) -- !query 42 schema -struct +struct -- !query 42 output -0 +1 -- !query 43 -select floor(1) +select ceiling(-0.10) -- !query 43 schema -struct +struct -- !query 43 output -1 +0 -- !query 44 -select floor(1234567890123456) +select floor(0) -- !query 44 schema -struct +struct -- !query 44 output +0 + + +-- !query 45 +select floor(1) +-- !query 45 schema +struct +-- !query 45 output +1 + + +-- !query 46 +select floor(1234567890123456) +-- !query 46 schema +struct +-- !query 46 output 1234567890123456 + + +-- !query 47 +select floor(0.01) +-- !query 47 schema +struct +-- !query 47 output +0 + + +-- !query 48 +select floor(-0.10) +-- !query 48 schema +struct +-- !query 48 output +-1 + + +-- !query 49 +select 1 > 0.00001 +-- !query 49 schema +struct<(CAST(1 AS BIGINT) > 0):boolean> +-- !query 49 output +true + + +-- !query 50 +select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null) +-- !query 50 schema +struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> +-- !query 50 output +1 NULL 0 NULL NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 4093a7b9fc820..52eb554edf89e 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 8 -- !query 0 @@ -78,3 +78,11 @@ select length(uuid()), (uuid() <> uuid()) struct -- !query 6 output 36 true + + +-- !query 7 +select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null) +-- !query 7 schema +struct +-- !query 7 output +4 NULL NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 7b495656b93d7..45afbd29d1907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -191,6 +191,17 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil) } } + + test("SPARK-21041 SparkSession.range()'s behavior is inconsistent with SparkContext.range()") { + val start = java.lang.Long.MAX_VALUE - 3 + val end = java.lang.Long.MIN_VALUE + 2 + Seq("false", "true").foreach { 
value => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value) { + assert(spark.range(start, end, 1).collect.length == 0) + assert(spark.range(start, start, 1).collect.length == 0) + } + } + } } object DataFrameRangeSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 7e2949ab5aece..4126660b5d102 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.immutable.Queue +import scala.collection.mutable.{LinkedHashMap => LHMap} import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.test.SharedSQLContext @@ -30,8 +31,14 @@ case class ListClass(l: List[Int]) case class QueueClass(q: Queue[Int]) +case class MapClass(m: Map[Int, Int]) + +case class LHMapClass(m: LHMap[Int, Int]) + case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) +case class ComplexMapClass(map: MapClass, lhmap: LHMapClass) + package object packageobject { case class PackageClass(value: Int) } @@ -258,11 +265,90 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))) } + test("arbitrary maps") { + checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2)) + checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong)) + checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble)) + checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat)) + checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte)) + checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort)) + checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false)) + checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2")) + checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2))) + checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong)) + + checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2)) + checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong)) + checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble)) + checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat)) + checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte)) + checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort)) + checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false)) + checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2")) + checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2))) + checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong)) + } + + ignore("SPARK-19104: map and product combinations") { + // Case classes + checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2))) + checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 
4)))).toDS(), + Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + + checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2))) + checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + Map(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + LHMap(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + + val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4))) + checkDataset(Seq(complex).toDS(), complex) + checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex)) + checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5)) + checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex)) + checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex)) + checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5)) + checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex)) + + // Tuples + checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(), + LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))) + + // Complex + checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(), + LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))) + } + test("nested sequences") { checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) } + test("nested maps") { + checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3))) + checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3)) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 539c63d3cb288..6b98209fd49b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -43,6 +43,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"), Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil) + // Null values + checkAnswer(df.selectExpr("stack(3, 
1, 1.1, null, 2, null, 'b', null, 3.3, 'c')"), + Row(1, 1.1, null) :: Row(2, null, "b") :: Row(null, 3.3, "c") :: Nil) + // Repeat generation at every input row checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"), Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a7efcafa0166a..68f61cfab6d2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,12 +23,9 @@ import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils -import org.apache.spark.sql.execution.{ScalarSubquery, SubqueryExec} import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -703,38 +700,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } - test("Verify spark.sql.subquery.reuse") { - Seq(true, false).foreach { reuse => - withSQLConf(SQLConf.SUBQUERY_REUSE_ENABLED.key -> reuse.toString) { - val df = sql( - """ - |SELECT key, (SELECT avg(key) FROM testData) - |FROM testData - |WHERE key > (SELECT avg(key) FROM testData) - |ORDER BY key - |LIMIT 3 - """.stripMargin) - - checkAnswer(df, Row(51, 50.5) :: Row(52, 50.5) :: Row(53, 50.5) :: Nil) - - val subqueries = ArrayBuffer.empty[SubqueryExec] - df.queryExecution.executedPlan.transformAllExpressions { - case s @ ScalarSubquery(plan: SubqueryExec, _) => - subqueries += plan - s - } - - assert(subqueries.size == 2, "Two ScalarSubquery are expected in the plan") - - if (reuse) { - assert(subqueries.distinct.size == 1, "Only one ScalarSubquery exists in the plan") - } else { - assert(subqueries.distinct.size == 2, "Reuse is not expected") - } - } - } - } - test("cartesian product join") { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 6cf18de0cc768..50d8e3024598d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -111,8 +111,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 1024, // initial capacity, - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) assert(!map.iterator().next()) map.free() @@ -125,8 +124,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 1024, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) val groupKey = InternalRow(UTF8String.fromString("cats")) @@ -152,8 +150,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) val rand = new Random(42) val 
groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet @@ -178,8 +175,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) val keys = randomStrings(1024).take(512) @@ -226,8 +222,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) val sorter = map.destructAndCreateExternalSorter() @@ -267,8 +262,7 @@ class UnsafeFixedWidthAggregationMapSuite StructType(Nil), taskMemoryManager, 128, // initial capacity - PAGE_SIZE_BYTES, - false // disable perf metrics + PAGE_SIZE_BYTES ) (1 to 10).foreach { i => val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0)) @@ -312,8 +306,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - pageSize, - false // disable perf metrics + pageSize ) val rand = new Random(42) @@ -350,8 +343,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, taskMemoryManager, 128, // initial capacity - pageSize, - false // disable perf metrics + pageSize ) val rand = new Random(42) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 352dba79a4c08..89d9b69dec7ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -261,10 +261,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val cars = spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "dropmalformed")) .load(testFile(carsFile)) @@ -284,11 +284,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for FAILFAST parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val exception = intercept[SparkException] { spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "failfast")) .load(testFile(carsFile)).collect() } @@ -990,13 +990,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val schema = new StructType().add("a", IntegerType).add("b", TimestampType) // We use `PERMISSIVE` mode by default if invalid string is given. 
val df1 = spark .read .option("mode", "abcd") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema) .csv(testFile(valueMalformedFile)) checkAnswer(df1, @@ -1011,7 +1011,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) checkAnswer(df2, @@ -1028,7 +1028,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField2) .csv(testFile(valueMalformedFile)) checkAnswer(df3, @@ -1041,7 +1041,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) .csv(testFile(valueMalformedFile)) .collect @@ -1073,7 +1073,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val df = spark.read .option("header", true) - .option("wholeFile", true) + .option("multiLine", true) .csv(path.getAbsolutePath) // Check if headers have new lines in the names. @@ -1096,10 +1096,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Empty file produces empty dataframe with empty schema") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val df = spark.read.format("csv") .option("header", true) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .load(testFile(emptyFile)) assert(df.schema === spark.emptyDataFrame.schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 65472cda9c1c0..704823ad516c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1814,7 +1814,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(new File(path).listFiles().exists(_.getName.endsWith(".gz"))) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .option("compression", "gZiP") @@ -1836,7 +1836,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write.json(jsonDir) @@ -1865,7 +1865,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) // no corrupt record column should be created assert(jsonDF.schema === StructType(Seq())) // only the first object should be read @@ -1886,7 +1886,7 @@ class JsonSuite extends 
QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path) assert(jsonDF.count() === corruptRecordCount) assert(jsonDF.schema === new StructType() .add("_corrupt_record", StringType) @@ -1917,7 +1917,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "DROPMALFORMED").json(path) checkAnswer(jsonDF, Seq(Row("test"))) } } @@ -1940,7 +1940,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // `FAILFAST` mode should throw an exception for corrupt records. val exceptionOne = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .json(path) } @@ -1949,7 +1949,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val exceptionTwo = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .schema(schema) .json(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a4e62f1d16792..a12ce2b9eba34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.metric import java.io.File import scala.collection.mutable.HashMap +import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} @@ -35,18 +36,18 @@ import org.apache.spark.util.{AccumulatorContext, JsonProtocol} class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ + /** - * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * Call `df.collect()` and collect necessary metrics from execution data. * * @param df `DataFrame` to run * @param expectedNumOfJobs number of jobs that will run - * @param expectedMetrics the expected metrics. The format is - * `nodeId -> (operatorName, metric name -> metric value)`. + * @param expectedNodeIds the node ids of the metrics to collect from execution data. 
*/ - private def testSparkPlanMetrics( + private def getSparkPlanMetrics( df: DataFrame, expectedNumOfJobs: Int, - expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + expectedNodeIds: Set[Long]): Option[Map[Long, (String, Map[String, Any])]] = { val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet withSQLConf("spark.sql.codegen.wholeStage" -> "false") { df.collect() @@ -63,9 +64,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + val metrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( df.queryExecution.executedPlan)).allNodes.filter { node => - expectedMetrics.contains(node.id) + expectedNodeIds.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => val metricValue = metricValues(metric.accumulatorId) @@ -73,7 +74,30 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { }.toMap (node.id, node.name -> nodeMetrics) }.toMap + Some(metrics) + } else { + // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. + // Since we cannot track all jobs, the metric values could be wrong and we should not check + // them. + logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + None + } + } + /** + * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. + */ + private def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet) + optActualMetrics.map { actualMetrics => assert(expectedMetrics.keySet === actualMetrics.keySet) for (nodeId <- expectedMetrics.keySet) { val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) @@ -83,11 +107,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) } } - } else { - // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. - // Since we cannot track all jobs, the metric values could be wrong and we should not check - // them. - logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") } } @@ -130,19 +149,47 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // ... 
-> HashAggregate(nodeId = 2) -> Exchange(nodeId = 1) // -> HashAggregate(nodeId = 0) val df = testData2.groupBy().count() // 2 partitions + val expected1 = Seq( + Map("number of output rows" -> 2L, + "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"), + Map("number of output rows" -> 1L, + "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)")) testSparkPlanMetrics(df, 1, Map( - 2L -> ("HashAggregate", Map("number of output rows" -> 2L)), - 0L -> ("HashAggregate", Map("number of output rows" -> 1L))) + 2L -> ("HashAggregate", expected1(0)), + 0L -> ("HashAggregate", expected1(1))) ) // 2 partitions and each partition contains 2 keys val df2 = testData2.groupBy('a).count() + val expected2 = Seq( + Map("number of output rows" -> 4L, + "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"), + Map("number of output rows" -> 3L, + "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)")) testSparkPlanMetrics(df2, 1, Map( - 2L -> ("HashAggregate", Map("number of output rows" -> 4L)), - 0L -> ("HashAggregate", Map("number of output rows" -> 3L))) + 2L -> ("HashAggregate", expected2(0)), + 0L -> ("HashAggregate", expected2(1))) ) } + test("Aggregate metrics: track avg probe") { + val random = new Random() + val manyBytes = (0 until 65535).map { _ => + val byteArrSize = random.nextInt(100) + val bytes = new Array[Byte](byteArrSize) + random.nextBytes(bytes) + (bytes, random.nextInt(100)) + } + val df = manyBytes.toSeq.toDF("a", "b").repartition(1).groupBy('a).count() + val metrics = getSparkPlanMetrics(df, 1, Set(2L, 0L)).get + Seq(metrics(2L)._2("avg hashmap probe (min, med, max)"), + metrics(0L)._2("avg hashmap probe (min, med, max)")).foreach { probes => + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toInt > 1) + } + } + } + test("ObjectHashAggregate metrics") { // Assume the execution plan is // ... -> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala new file mode 100644 index 0000000000000..bdba536425a43 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => + source.asInstanceOf[RateStreamSource] + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + (rateSource, rateSource.getOffset.get) + } + } + + test("basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("valueAtSecond") { + import RateStreamSource._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 
-> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[ArithmeticException](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[StreamingQueryException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + for (msg <- expectedMessages) { + assert(e.getCause.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index f2456c7704064..135370bd1d677 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -37,6 +37,9 @@ class SQLConfEntrySuite extends SparkFunSuite { assert(conf.getConfString(key) === "20") assert(conf.getConf(confEntry, 5) === 20) + conf.setConfString(key, " 20") + assert(conf.getConf(confEntry, 5) === 20) + val e = intercept[IllegalArgumentException] { conf.setConfString(key, "abc") } @@ -75,6 +78,8 @@ class SQLConfEntrySuite extends SparkFunSuite { assert(conf.getConfString(key) === "true") assert(conf.getConf(confEntry, false) === true) + conf.setConfString(key, " true ") + assert(conf.getConf(confEntry, false) === true) val e = intercept[IllegalArgumentException] { conf.setConfString(key, "abc") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 5bc36dd30f6d1..2a4039cc5831a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -172,8 +172,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { * * @param isFatalError if this is a fatal error. 
If so, the error should also be caught by * UncaughtExceptionHandler. + * @param assertFailure a function to verify the error. */ case class ExpectFailure[T <: Throwable : ClassTag]( + assertFailure: Throwable => Unit = _ => {}, isFatalError: Boolean = false) extends StreamAction { val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] override def toString(): String = @@ -455,6 +457,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") streamThreadDeathCause = null } + ef.assertFailure(exception.getCause) } catch { case _: InterruptedException => case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index ff3784cab9e26..1d1074a2a7387 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -253,6 +253,8 @@ private[hive] class SparkExecuteStatementOperation( return } else { setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + statementId, e.getMessage, SparkUtils.exceptionString(e)) throw e } // Actually do need to catch Throwable as some failures don't inherit from Exception and diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 918459fe7c246..19453679a30df 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -527,7 +527,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat /** * Alter a table whose name that matches the one specified in `tableDefinition`, - * assuming the table exists. + * assuming the table exists. This method does not change the properties for data source and + * statistics. * * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. 
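// A minimal sketch (not the actual HiveExternalCatalog code) of the property-merging rule
// the rewritten alterTable applies: keys carrying Spark's data source metadata or statistics
// are copied over from the old table definition, so a user-issued ALTER TABLE cannot drop them.
// The prefix strings below are illustrative placeholders rather than the real constants.
object AlterTablePropsSketch {
  private val DATASOURCE_PREFIX = "spark.sql.sources."    // assumed prefix, for illustration
  private val STATISTICS_PREFIX = "spark.sql.statistics." // assumed prefix, for illustration

  def mergedProperties(
      oldProps: Map[String, String],
      newProps: Map[String, String]): Map[String, String] = {
    // Preserve Spark-internal properties from the old definition ...
    val preserved = oldProps.filter { case (k, _) =>
      k.startsWith(DATASOURCE_PREFIX) || k.startsWith(STATISTICS_PREFIX)
    }
    // ... and let the caller-supplied properties take precedence for everything else.
    preserved ++ newProps
  }
}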
@@ -538,30 +539,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, tableDefinition.identifier.table) verifyTableProperties(tableDefinition) - // convert table statistics to properties so that we can persist them through hive api - val withStatsProps = if (tableDefinition.stats.isDefined) { - val stats = tableDefinition.stats.get - var statsProperties: Map[String, String] = - Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) - if (stats.rowCount.isDefined) { - statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() - } - val colNameTypeMap: Map[String, DataType] = - tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap - stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) - } - } - tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) - } else { - tableDefinition - } - if (tableDefinition.tableType == VIEW) { - client.alterTable(withStatsProps) + client.alterTable(tableDefinition) } else { - val oldTableDef = getRawTable(db, withStatsProps.identifier.table) + val oldTableDef = getRawTable(db, tableDefinition.identifier.table) val newStorage = if (DDLUtils.isHiveTable(tableDefinition)) { tableDefinition.storage @@ -611,12 +592,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat TABLE_PARTITION_PROVIDER -> TABLE_PARTITION_PROVIDER_FILESYSTEM } - // Sets the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, - // to retain the spark specific format if it is. Also add old data source properties to table - // properties, to retain the data source table format. - val oldDataSourceProps = oldTableDef.properties.filter(_._1.startsWith(DATASOURCE_PREFIX)) - val newTableProps = oldDataSourceProps ++ withStatsProps.properties + partitionProviderProp - val newDef = withStatsProps.copy( + // Add old data source properties to table properties, to retain the data source table format. + // Add old stats properties to table properties, to retain spark's stats. + // Set the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, + // to retain the spark specific format if it is. 
+ val propsFromOldTable = oldTableDef.properties.filter { case (k, v) => + k.startsWith(DATASOURCE_PREFIX) || k.startsWith(STATISTICS_PREFIX) + } + val newTableProps = propsFromOldTable ++ tableDefinition.properties + partitionProviderProp + val newDef = tableDefinition.copy( storage = newStorage, schema = oldTableDef.schema, partitionColumnNames = oldTableDef.partitionColumnNames, @@ -647,6 +631,32 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + override def alterTableStats( + db: String, + table: String, + stats: CatalogStatistics): Unit = withClient { + requireTableExists(db, table) + val rawTable = getRawTable(db, table) + + // convert table statistics to properties so that we can persist them through hive client + var statsProperties: Map[String, String] = + Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) + if (stats.rowCount.isDefined) { + statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() + } + val colNameTypeMap: Map[String, DataType] = + rawTable.schema.fields.map(f => (f.name, f.dataType)).toMap + stats.colStats.foreach { case (colName, colStat) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } + } + + val oldTableNonStatsProps = rawTable.properties.filterNot(_._1.startsWith(STATISTICS_PREFIX)) + val updatedTable = rawTable.copy(properties = oldTableNonStatsProps ++ statsProperties) + client.alterTable(updatedTable) + } + override def getTable(db: String, table: String): CatalogTable = withClient { restoreTableMetadata(getRawTable(db, table)) } @@ -719,6 +729,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }) } + // Reorder table schema to put partition columns at the end. Before Spark 2.2, the partition + // columns are not put at the end of schema. We need to reorder it when reading the schema + // from the table properties. + private def reorderSchema(schema: StructType, partColumnNames: Seq[String]): StructType = { + val partitionFields = partColumnNames.map { partCol => + schema.find(_.name == partCol).getOrElse { + throw new AnalysisException("The metadata is corrupted. Unable to find the " + + s"partition column names from the schema. schema: ${schema.catalogString}. " + + s"Partition columns: ${partColumnNames.mkString("[", ", ", "]")}") + } + } + StructType(schema.filterNot(partitionFields.contains) ++ partitionFields) + } + private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = { val hiveTable = table.copy( provider = Some(DDLUtils.HIVE_PROVIDER), @@ -728,10 +752,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // schema from table properties. 
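// Illustrative only: what the new reorderSchema helper does, shown on plain (name, type)
// pairs instead of StructType so the sketch stands alone. Partition columns are moved to
// the end of the schema while their relative order is preserved, matching what the
// HiveExternalCatalogSuite test below expects.
object ReorderSchemaSketch {
  def reorder(fields: Seq[(String, String)], partCols: Seq[String]): Seq[(String, String)] = {
    val partitionFields = partCols.map { p =>
      fields.find(_._1 == p).getOrElse(sys.error(s"partition column $p not found in schema"))
    }
    fields.filterNot(partitionFields.contains) ++ partitionFields
  }
  // reorder(Seq("col1" -> "int", "partCol1" -> "int", "partCol2" -> "string", "col2" -> "string"),
  //         Seq("partCol1", "partCol2"))
  // returns col1, col2, partCol1, partCol2 -- partition columns last.
}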
if (table.properties.contains(DATASOURCE_SCHEMA_NUMPARTS)) { val schemaFromTableProps = getSchemaFromTableProperties(table) - if (DataType.equalsIgnoreCaseAndNullability(schemaFromTableProps, table.schema)) { + val partColumnNames = getPartitionColumnsFromTableProperties(table) + val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) + + if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema)) { hiveTable.copy( - schema = schemaFromTableProps, - partitionColumnNames = getPartitionColumnsFromTableProperties(table), + schema = reorderedSchema, + partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table)) } else { // Hive metastore may change the table schema, e.g. schema inference. If the table @@ -761,11 +788,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) + val schemaFromTableProps = getSchemaFromTableProperties(table) + val partColumnNames = getPartitionColumnsFromTableProperties(table) + val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) + table.copy( provider = Some(provider), storage = storageWithLocation, - schema = getSchemaFromTableProperties(table), - partitionColumnNames = getPartitionColumnsFromTableProperties(table), + schema = reorderedSchema, + partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table), tracksPartitionsInCatalog = partitionProvider == Some(TABLE_PARTITION_PROVIDER_CATALOG)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index bd54c043c6ec4..d43534d5914d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -63,4 +63,30 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { assert(!rawTable.properties.contains(HiveExternalCatalog.DATASOURCE_PROVIDER)) assert(DDLUtils.isHiveTable(externalCatalog.getTable("db1", "hive_tbl"))) } + + Seq("parquet", "hive").foreach { format => + test(s"Partition columns should be put at the end of table schema for the format $format") { + val catalog = newBasicCatalog() + val newSchema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string") + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType() + .add("col1", "int") + .add("partCol1", "int") + .add("partCol2", "string") + .add("col2", "string"), + provider = Some(format), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, ignoreIfExists = false) + + val restoredTable = externalCatalog.getTable("db1", "tbl") + assert(restoredTable.schema == newSchema) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 5d52f8baa3b94..001bbc230ff18 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -25,7 +25,7 @@ import scala.util.matching.Regex import org.apache.spark.sql._ import 
org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -267,7 +267,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - test("get statistics when not analyzed in both Hive and Spark") { + test("get statistics when not analyzed in Hive or Spark") { val tabName = "tab1" withTable(tabName) { createNonPartitionedTable(tabName, analyzedByHive = false, analyzedBySpark = false) @@ -313,60 +313,70 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - test("alter table SET TBLPROPERTIES after analyze table") { - Seq(true, false).foreach { analyzedBySpark => - val tabName = "tab1" - withTable(tabName) { - createNonPartitionedTable(tabName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) - val fetchedStats1 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('foo' = 'a')") - val fetchedStats2 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - assert(fetchedStats1 == fetchedStats2) - - val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") - - val totalSize = extractStatsPropValues(describeResult, "totalSize") - assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + test("alter table should not have the side effect to store statistics in Spark side") { + def getCatalogTable(tableName: String): CatalogTable = { + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + } - // ALTER TABLE SET TBLPROPERTIES invalidates some Hive specific statistics - // This is triggered by the Hive alterTable API - val numRows = extractStatsPropValues(describeResult, "numRows") - assert(numRows.isDefined && numRows.get == -1, "numRows is lost") - val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") - assert(rawDataSize.isDefined && rawDataSize.get == -1, "rawDataSize is lost") - } + val table = "alter_table_side_effect" + withTable(table) { + sql(s"CREATE TABLE $table (i string, j string)") + sql(s"INSERT INTO TABLE $table SELECT 'a', 'b'") + val catalogTable1 = getCatalogTable(table) + val hiveSize1 = BigInt(catalogTable1.ignoredProperties(StatsSetupConst.TOTAL_SIZE)) + + sql(s"ALTER TABLE $table SET TBLPROPERTIES ('prop1' = 'a')") + + sql(s"INSERT INTO TABLE $table SELECT 'c', 'd'") + val catalogTable2 = getCatalogTable(table) + val hiveSize2 = BigInt(catalogTable2.ignoredProperties(StatsSetupConst.TOTAL_SIZE)) + // After insertion, Hive's stats should be changed. + assert(hiveSize2 > hiveSize1) + // We haven't generate stats in Spark, so we should still use Hive's stats here. 
+ assert(catalogTable2.stats.get.sizeInBytes == hiveSize2) } } - test("alter table UNSET TBLPROPERTIES after analyze table") { + private def testAlterTableProperties(tabName: String, alterTablePropCmd: String): Unit = { Seq(true, false).foreach { analyzedBySpark => - val tabName = "tab1" withTable(tabName) { createNonPartitionedTable(tabName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) - val fetchedStats1 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - sql(s"ALTER TABLE $tabName UNSET TBLPROPERTIES ('prop1')") - val fetchedStats2 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - assert(fetchedStats1 == fetchedStats2) + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + + // Run ALTER TABLE command + sql(alterTablePropCmd) val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") val totalSize = extractStatsPropValues(describeResult, "totalSize") assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") - // ALTER TABLE UNSET TBLPROPERTIES invalidates some Hive specific statistics - // This is triggered by the Hive alterTable API + // ALTER TABLE SET/UNSET TBLPROPERTIES invalidates some Hive specific statistics, but not + // Spark specific statistics. This is triggered by the Hive alterTable API. val numRows = extractStatsPropValues(describeResult, "numRows") assert(numRows.isDefined && numRows.get == -1, "numRows is lost") val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") assert(rawDataSize.isDefined && rawDataSize.get == -1, "rawDataSize is lost") + + if (analyzedBySpark) { + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + } else { + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = None) + } } } } + test("alter table SET TBLPROPERTIES after analyze table") { + testAlterTableProperties("set_prop_table", + "ALTER TABLE set_prop_table SET TBLPROPERTIES ('foo' = 'a')") + } + + test("alter table UNSET TBLPROPERTIES after analyze table") { + testAlterTableProperties("unset_prop_table", + "ALTER TABLE unset_prop_table UNSET TBLPROPERTIES ('prop1')") + } + test("add/drop partitions - managed table") { val catalog = spark.sessionState.catalog val managedTable = "partitionedTable" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index f818e29555468..d91f25a4da013 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -66,4 +67,28 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te } } } + + test("SPARK-20986 Reset table's statistics after PruneFileSourcePartitions rule") { + withTable("tbl") { + spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl") + sql(s"ANALYZE TABLE tbl COMPUTE STATISTICS") + val tableStats = 
spark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")).stats + assert(tableStats.isDefined && tableStats.get.sizeInBytes > 0, "tableStats is lost") + + val df = sql("SELECT * FROM tbl WHERE p = 1") + val sizes1 = df.queryExecution.analyzed.collect { + case relation: LogicalRelation => relation.catalogTable.get.stats.get.sizeInBytes + } + assert(sizes1.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizes1(0) == tableStats.get.sizeInBytes) + + val relations = df.queryExecution.optimizedPlan.collect { + case relation: LogicalRelation => relation + } + assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}") + val size2 = relations(0).computeStats(conf).sizeInBytes + assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes) + assert(size2 < tableStats.get.sizeInBytes) + } + } } diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 2803cad8095dd..00c59728748f6 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -56,7 +56,7 @@ public abstract class WriteAheadLog { public abstract void clean(long threshTime, boolean waitForCompletion); /** - * Close this log and release any resources. + * Close this log and release any resources. It must be idempotent. */ public abstract void close(); } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index bd7ab0b9bf5eb..6f130c803f310 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -165,11 +165,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Stop the receiver execution thread. */ def stop(graceful: Boolean): Unit = synchronized { - if (isTrackerStarted) { - // First, stop the receivers - trackerState = Stopping + val isStarted: Boolean = isTrackerStarted + trackerState = Stopping + if (isStarted) { if (!skipReceiverLaunch) { - // Send the stop signal to all the receivers + // First, stop the receivers. Send the stop signal to all the receivers endpoint.askSync[Boolean](StopAllReceivers) // Wait for the Spark job that runs the receivers to be over @@ -194,17 +194,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Finally, stop the endpoint ssc.env.rpcEnv.stop(endpoint) endpoint = null - receivedBlockTracker.stop() - logInfo("ReceiverTracker stopped") - trackerState = Stopped - } else if (isTrackerInitialized) { - trackerState = Stopping - // `ReceivedBlockTracker` is open when this instance is created. We should - // close this even if this `ReceiverTracker` is not started. - receivedBlockTracker.stop() - logInfo("ReceiverTracker stopped") - trackerState = Stopped } + + // `ReceivedBlockTracker` is open when this instance is created. We should + // close this even if this `ReceiverTracker` is not started. + receivedBlockTracker.stop() + logInfo("ReceiverTracker stopped") + trackerState = Stopped } /** Allocate all unallocated blocks to the given batch. 
*/ @@ -453,9 +449,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false endpoint.send(StartAllReceivers(receivers)) } - /** Check if tracker has been marked for initiated */ - private def isTrackerInitialized: Boolean = trackerState == Initialized - /** Check if tracker has been marked for starting */ private def isTrackerStarted: Boolean = trackerState == Started diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 35f0166ed0cf2..e522bc62d5cac 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import java.util.{Iterator => JIterator} +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -60,7 +61,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp private val walWriteQueue = new LinkedBlockingQueue[Record]() // Whether the writer thread is active - @volatile private var active: Boolean = true + private val active: AtomicBoolean = new AtomicBoolean(true) private val buffer = new ArrayBuffer[Record]() private val batchedWriterThread = startBatchedWriterThread() @@ -72,7 +73,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { val promise = Promise[WriteAheadLogRecordHandle]() val putSuccessfully = synchronized { - if (active) { + if (active.get()) { walWriteQueue.offer(Record(byteBuffer, time, promise)) true } else { @@ -121,9 +122,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp */ override def close(): Unit = { logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.") - synchronized { - active = false - } + if (!active.getAndSet(false)) return batchedWriterThread.interrupt() batchedWriterThread.join() while (!walWriteQueue.isEmpty) { @@ -138,7 +137,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp private def startBatchedWriterThread(): Thread = { val thread = new Thread(new Runnable { override def run(): Unit = { - while (active) { + while (active.get()) { try { flushRecords() } catch { @@ -166,7 +165,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp } try { var segment: WriteAheadLogRecordHandle = null - if (buffer.length > 0) { + if (buffer.nonEmpty) { logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") // threads may not be able to add items in order by time val sortedByTime = buffer.sortBy(_.time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 1e5f18797e152..d6e15cfdd2723 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -205,10 +205,12 @@ private[streaming] class FileBasedWriteAheadLog( /** Stop the manager, close any open log writer */ def close(): Unit = synchronized { - if (currentLogWriter != null) { - 
currentLogWriter.close() + if (!executionContext.isShutdown) { + if (currentLogWriter != null) { + currentLogWriter.close() + } + executionContext.shutdown() } - executionContext.shutdown() logInfo("Stopped write ahead log manager") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index df122ac090c3e..c206d3169d77e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -57,6 +57,8 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } finally { tracker.stop(false) + // Make sure it is idempotent. + tracker.stop(false) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 4bec52b9fe4fe..ede15399f0e2f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -140,6 +140,8 @@ abstract class CommonWriteAheadLogTests( } } writeAheadLog.close() + // Make sure it is idempotent. + writeAheadLog.close() } test(testPrefix + "handling file errors while reading rotating logs") {
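// A hedged sketch of the close-once pattern this patch applies to the write ahead logs:
// an AtomicBoolean flips exactly once, so the second close() that the updated tests issue
// returns immediately instead of repeating shutdown work. The class below is illustrative,
// not an actual Spark type.
import java.util.concurrent.atomic.AtomicBoolean

class CloseOnceSketch {
  private val active = new AtomicBoolean(true)

  def close(): Unit = {
    // getAndSet returns the previous value, so only the first caller observes `true`.
    if (!active.getAndSet(false)) return
    // ... release writers, shut down executors, etc., exactly once ...
  }
}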