diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore
index f12f8c275a989..18b2db69db8f1 100644
--- a/R/pkg/.Rbuildignore
+++ b/R/pkg/.Rbuildignore
@@ -6,3 +6,4 @@
 ^README\.Rmd$
 ^src-native$
 ^html$
+^tests/fulltests/*
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 166b39813c14e..3b9d42d6e7158 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2646,6 +2646,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
 #' Input SparkDataFrames can have different schemas (names and data types).
 #'
 #' Note: This does not remove duplicate rows across the two SparkDataFrames.
+#' Also as standard in SQL, this function resolves columns by position (not by name).
 #'
 #' @param x A SparkDataFrame
 #' @param y A SparkDataFrame
diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R
index 4ca7aa664e023..ec931befa2854 100644
--- a/R/pkg/R/install.R
+++ b/R/pkg/R/install.R
@@ -267,7 +267,7 @@ hadoopVersionName <- function(hadoopVersion) {
 # The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and
 # adapt to Spark context
 sparkCachePath <- function() {
-  if (.Platform$OS.type == "windows") {
+  if (is_windows()) {
     winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA)
     if (is.na(winAppPath)) {
       stop(paste("%LOCALAPPDATA% not found.",
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index ea45e394500e8..91483a4d23d9b 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -908,10 +908,6 @@ isAtomicLengthOne <- function(x) {
   is.atomic(x) && length(x) == 1
 }
 
-is_cran <- function() {
-  !identical(Sys.getenv("NOT_CRAN"), "true")
-}
-
 is_windows <- function() {
   .Platform$OS.type == "windows"
 }
@@ -920,6 +916,6 @@ hadoop_home_set <- function() {
   !identical(Sys.getenv("HADOOP_HOME"), "")
 }
 
-not_cran_or_windows_with_hadoop <- function() {
-  !is_cran() && (!is_windows() || hadoop_home_set())
+windows_with_hadoop <- function() {
+  !is_windows() || hadoop_home_set()
 }
diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R
new file mode 100644
index 0000000000000..de47162d5325f
--- /dev/null
+++ b/R/pkg/inst/tests/testthat/test_basic.R
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("basic tests for CRAN")
+
+test_that("create DataFrame from list or data.frame", {
+  sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
+
+  i <- 4
+  df <- createDataFrame(data.frame(dummy = 1:i))
+  expect_equal(count(df), i)
+
+  l <- list(list(a = 1, b = 2), list(a = 3, b = 4))
+  df <- createDataFrame(l)
+  expect_equal(columns(df), c("a", "b"))
+
+  a <- 1:3
+  b <- c("a", "b", "c")
+  ldf <- data.frame(a, b)
+  df <- createDataFrame(ldf)
+  expect_equal(columns(df), c("a", "b"))
+  expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+  expect_equal(count(df), 3)
+  ldf2 <- collect(df)
+  expect_equal(ldf$a, ldf2$a)
+
+  mtcarsdf <- createDataFrame(mtcars)
+  expect_equivalent(collect(mtcarsdf), mtcars)
+
+  bytes <- as.raw(c(1, 2, 3))
+  df <- createDataFrame(list(list(bytes)))
+  expect_equal(collect(df)[[1]][[1]], bytes)
+
+  sparkR.session.stop()
+})
+
+test_that("spark.glm and predict", {
+  sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
+
+  training <- suppressWarnings(createDataFrame(iris))
+  # gaussian family
+  model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species)
+  prediction <- predict(model, training)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+  vals <- collect(select(prediction, "prediction"))
+  rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+  # Gamma family
+  x <- runif(100, -1, 1)
+  y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
+  df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
+  model <- glm(y ~ x, family = Gamma, df)
+  out <- capture.output(print(summary(model)))
+  expect_true(any(grepl("Dispersion parameter for gamma family", out)))
+
+  # tweedie family
+  model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
+                     family = "tweedie", var.power = 1.2, link.power = 0.0)
+  prediction <- predict(model, training)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+  vals <- collect(select(prediction, "prediction"))
+
+  # manual calculation of the R predicted values to avoid dependence on statmod
+  #' library(statmod)
+  #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris,
+  #'               family = tweedie(var.power = 1.2, link.power = 0.0))
+  #' print(coef(rModel))
+
+  rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174)
+  rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species,
+                                       data = iris) %*% rCoef))
+  expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals)
+
+  sparkR.session.stop()
+})
diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/tests/fulltests/jarTest.R
similarity index 100%
rename from R/pkg/inst/tests/testthat/jarTest.R
rename to R/pkg/tests/fulltests/jarTest.R
diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/tests/fulltests/packageInAJarTest.R
similarity index 100%
rename from R/pkg/inst/tests/testthat/packageInAJarTest.R
rename to R/pkg/tests/fulltests/packageInAJarTest.R
diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R
similarity index 97%
rename from R/pkg/inst/tests/testthat/test_Serde.R
rename to R/pkg/tests/fulltests/test_Serde.R
index 6e160fae1afed..6bbd201bf1d82 100644
--- a/R/pkg/inst/tests/testthat/test_Serde.R
+++ b/R/pkg/tests/fulltests/test_Serde.R
@@ -20,8 +20,6 @@ context("SerDe functionality")
 sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
 
 test_that("SerDe of primitive types", {
-  skip_on_cran()
-
   x <- callJStatic("SparkRHandler", "echo", 1L)
   expect_equal(x, 1L)
   expect_equal(class(x), "integer")
@@ -40,8 +38,6 @@ test_that("SerDe of primitive types", {
 })
 
 test_that("SerDe of list of primitive types", {
-  skip_on_cran()
-
   x <- list(1L, 2L, 3L)
   y <- callJStatic("SparkRHandler", "echo", x)
   expect_equal(x, y)
@@ -69,8 +65,6 @@ test_that("SerDe of list of primitive types", {
 })
 
 test_that("SerDe of list of lists", {
-  skip_on_cran()
-
   x <- list(list(1L, 2L, 3L), list(1, 2, 3), list(TRUE, FALSE),
             list("a", "b", "c"))
   y <- callJStatic("SparkRHandler", "echo", x)
diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R
similarity index 85%
rename from R/pkg/inst/tests/testthat/test_Windows.R
rename to R/pkg/tests/fulltests/test_Windows.R
index 00d684e1a49ef..b2ec6c67311db 100644
--- a/R/pkg/inst/tests/testthat/test_Windows.R
+++ b/R/pkg/tests/fulltests/test_Windows.R
@@ -17,9 +17,7 @@
 context("Windows-specific tests")
 
 test_that("sparkJars tag in SparkContext", {
-  skip_on_cran()
-
-  if (.Platform$OS.type != "windows") {
+  if (!is_windows()) {
     skip("This test is only for Windows, skipped")
   }
 
@@ -27,6 +25,3 @@ test_that("sparkJars tag in SparkContext", {
   abcPath <- testOutput[1]
   expect_equal(abcPath, "a\\b\\c")
 })
-
-message("--- End test (Windows) ", as.POSIXct(Sys.time(), tz = "GMT"))
-message("elapsed ", (proc.time() - timer_ptm)[3])
diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/tests/fulltests/test_binaryFile.R
similarity index 97%
rename from R/pkg/inst/tests/testthat/test_binaryFile.R
rename to R/pkg/tests/fulltests/test_binaryFile.R
index 00954fa31b0ee..758b174b8787c 100644
--- a/R/pkg/inst/tests/testthat/test_binaryFile.R
+++ b/R/pkg/tests/fulltests/test_binaryFile.R
@@ -24,8 +24,6 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext",
 mockFile <- c("Spark is pretty.", "Spark is awesome.")
 
 test_that("saveAsObjectFile()/objectFile() following textFile() works", {
-  skip_on_cran()
-
   fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp")
   fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp")
   writeLines(mockFile, fileName1)
@@ -40,8 +38,6 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", {
 })
 
 test_that("saveAsObjectFile()/objectFile() works on a parallelized list", {
-  skip_on_cran()
-
   fileName <- tempfile(pattern = "spark-test", fileext = ".tmp")
 
   l <- list(1, 2, 3)
@@ -54,8 +50,6 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", {
 })
 
 test_that("saveAsObjectFile()/objectFile() following RDD transformations works", {
-  skip_on_cran()
-
   fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp")
   fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp")
   writeLines(mockFile, fileName1)
@@ -80,8 +74,6 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works",
 })
 
 test_that("saveAsObjectFile()/objectFile() works with multiple paths", {
-  skip_on_cran()
-
   fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp")
   fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp")
 
diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R
similarity index 98%
rename from R/pkg/inst/tests/testthat/test_binary_function.R
rename to R/pkg/tests/fulltests/test_binary_function.R
index 236cb3885445e..442bed509bb1d 100644
--- a/R/pkg/inst/tests/testthat/test_binary_function.R
+++ b/R/pkg/tests/fulltests/test_binary_function.R
@@
-29,8 +29,6 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { - skip_on_cran() - actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) @@ -53,8 +51,6 @@ test_that("union on two RDDs", { }) test_that("cogroup on two RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) @@ -73,8 +69,6 @@ test_that("cogroup on two RDDs", { }) test_that("zipPartitions() on RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/tests/fulltests/test_broadcast.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_broadcast.R rename to R/pkg/tests/fulltests/test_broadcast.R index 2c96740df77bb..fc2c7c2deb825 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/tests/fulltests/test_broadcast.R @@ -26,8 +26,6 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { - skip_on_cran() - randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcastRDD(sc, randomMat) @@ -40,8 +38,6 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { - skip_on_cran() - randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/tests/fulltests/test_client.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_client.R rename to R/pkg/tests/fulltests/test_client.R index 3d53bebab6300..0cf25fe1dbf39 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/tests/fulltests/test_client.R @@ -18,8 +18,6 @@ context("functions in client.R") test_that("adding spark-testing-base as a package works", { - skip_on_cran() - args <- generateSparkSubmitArgs("", "", "", "", "holdenk:spark-testing-base:1.3.0_0.0.5") expect_equal(gsub("[[:space:]]", "", args), @@ -28,22 +26,16 @@ test_that("adding spark-testing-base as a package works", { }) test_that("no package specified doesn't add packages flag", { - skip_on_cran() - args <- generateSparkSubmitArgs("", "", "", "", "") expect_equal(gsub("[[:space:]]", "", args), "") }) test_that("multiple packages don't produce a warning", { - skip_on_cran() - expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { - skip_on_cran() - args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/tests/fulltests/test_context.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_context.R rename to R/pkg/tests/fulltests/test_context.R index f6d9f5423df02..710485d56685a 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -18,8 +18,6 @@ context("test functions in sparkR.R") test_that("Check masked functions", { - skip_on_cran() - # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. 
@@ -57,8 +55,6 @@ test_that("Check masked functions", { }) test_that("repeatedly starting and stopping SparkR", { - skip_on_cran() - for (i in 1:4) { sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) rdd <- parallelize(sc, 1:20, 2L) @@ -77,8 +73,6 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { - skip_on_cran() - sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 @@ -102,8 +96,6 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { - skip_on_cran() - sc <- sparkR.sparkContext(master = sparkRTestMaster) setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") @@ -116,16 +108,12 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { - skip_on_cran() - sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { - skip_on_cran() - e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) @@ -153,8 +141,6 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli }) test_that("sparkJars sparkPackages as comma-separated strings", { - skip_on_cran() - expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) expect_equal(lapply(jars, basename), list("a", "b")) @@ -182,8 +168,6 @@ test_that("spark.lapply should perform simple transforms", { }) test_that("add and get file to be downloaded with Spark job on every node", { - skip_on_cran() - sparkR.sparkContext(master = sparkRTestMaster) # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/tests/fulltests/test_includePackage.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_includePackage.R rename to R/pkg/tests/fulltests/test_includePackage.R index d7d9eeed1575e..f4ea0d1b5cb27 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/tests/fulltests/test_includePackage.R @@ -26,8 +26,6 @@ nums <- 1:2 rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { - skip_on_cran() - # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) @@ -44,8 +42,6 @@ test_that("include inside function", { }) test_that("use include package", { - skip_on_cran() - # Only run the test if plyr is installed. 
if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/tests/fulltests/test_jvm_api.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_jvm_api.R rename to R/pkg/tests/fulltests/test_jvm_api.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_mllib_classification.R rename to R/pkg/tests/fulltests/test_mllib_classification.R index 82e588dc460d0..726e9d9a20b1c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -28,8 +28,6 @@ absoluteSparkPath <- function(x) { } test_that("spark.svmLinear", { - skip_on_cran() - df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10) @@ -51,7 +49,7 @@ test_that("spark.svmLinear", { expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) # Test model save and load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -131,7 +129,7 @@ test_that("spark.logit", { expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) # Test model save and load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -228,8 +226,6 @@ test_that("spark.logit", { }) test_that("spark.mlp", { - skip_on_cran() - df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), @@ -250,7 +246,7 @@ test_that("spark.mlp", { expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -363,7 +359,7 @@ test_that("spark.naiveBayes", { "Yes", "Yes", "No", "No")) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") write.ml(m, modelPath) expect_error(write.ml(m, modelPath)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/tests/fulltests/test_mllib_clustering.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_mllib_clustering.R rename to R/pkg/tests/fulltests/test_mllib_clustering.R index e827e961ab4c1..4110e13da4948 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/tests/fulltests/test_mllib_clustering.R @@ -28,8 +28,6 @@ absoluteSparkPath <- function(x) { } test_that("spark.bisectingKmeans", { - skip_on_cran() - newIris <- iris newIris$Species <- NULL training <- suppressWarnings(createDataFrame(newIris)) @@ -55,7 +53,7 @@ test_that("spark.bisectingKmeans", { c(0, 1, 2, 3)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") 
write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -129,7 +127,7 @@ test_that("spark.gaussianMixture", { expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -177,7 +175,7 @@ test_that("spark.kmeans", { expect_true(class(summary.model$coefficients[1, ]) == "numeric") # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -244,7 +242,7 @@ test_that("spark.lda with libsvm", { expect_true(logPrior <= 0 & !is.na(logPrior)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -265,8 +263,6 @@ test_that("spark.lda with libsvm", { }) test_that("spark.lda with text input", { - skip_on_cran() - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, optimizer = "online", features = "value") @@ -309,8 +305,6 @@ test_that("spark.lda with text input", { }) test_that("spark.posterior and spark.perplexity", { - skip_on_cran() - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, features = "value", k = 3) diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_mllib_fpm.R rename to R/pkg/tests/fulltests/test_mllib_fpm.R index 4e10ca1e4f50b..69dda52f0c279 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_fpm.R +++ b/R/pkg/tests/fulltests/test_mllib_fpm.R @@ -62,7 +62,7 @@ test_that("spark.fpGrowth", { expect_equivalent(expected_predictions, collect(predict(model, new_data))) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") write.ml(model, modelPath, overwrite = TRUE) loaded_model <- read.ml(modelPath) diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/tests/fulltests/test_mllib_recommendation.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_mllib_recommendation.R rename to R/pkg/tests/fulltests/test_mllib_recommendation.R index cc8064f88d27a..4d919c9d746b0 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R +++ b/R/pkg/tests/fulltests/test_mllib_recommendation.R @@ -37,7 +37,7 @@ test_that("spark.als", { tolerance = 1e-4) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R similarity index 99% rename from R/pkg/inst/tests/testthat/test_mllib_regression.R rename to R/pkg/tests/fulltests/test_mllib_regression.R index b05fdd350ca28..82472c92b9965 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/tests/fulltests/test_mllib_regression.R @@ -23,8 +23,6 @@ context("MLlib regression algorithms, except for tree-based algorithms") sparkSession <- 
sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("formula of spark.glm", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm @@ -197,8 +195,6 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -226,8 +222,6 @@ test_that("spark.glm save/load", { }) test_that("formula of glm", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . - Species + 0, data = training) @@ -254,8 +248,6 @@ test_that("formula of glm", { }) test_that("glm and predict", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) @@ -300,8 +292,6 @@ test_that("glm and predict", { }) test_that("glm summary", { - skip_on_cran() - # gaussian family training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -351,8 +341,6 @@ test_that("glm summary", { }) test_that("glm save/load", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) @@ -401,7 +389,7 @@ test_that("spark.isoreg", { expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -452,7 +440,7 @@ test_that("spark.survreg", { 2.390146, 2.891269, 2.891269), tolerance = 1e-4) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/tests/fulltests/test_mllib_stat.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_mllib_stat.R rename to R/pkg/tests/fulltests/test_mllib_stat.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_mllib_tree.R rename to R/pkg/tests/fulltests/test_mllib_tree.R index 31427ee52a5e9..9b3fc8d270b25 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/tests/fulltests/test_mllib_tree.R @@ -28,8 +28,6 @@ absoluteSparkPath <- function(x) { } test_that("spark.gbt", { - skip_on_cran() - # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) @@ -46,7 +44,7 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -80,7 +78,7 @@ test_that("spark.gbt", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - if (not_cran_or_windows_with_hadoop()) { + if 
(windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -105,7 +103,7 @@ test_that("spark.gbt", { expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), source = "libsvm") model <- spark.gbt(data, label ~ features, "classification") @@ -144,7 +142,7 @@ test_that("spark.randomForest", { expect_equal(stats$numTrees, 20) expect_equal(stats$maxDepth, 5) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -178,7 +176,7 @@ test_that("spark.randomForest", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -215,7 +213,7 @@ test_that("spark.randomForest", { expect_equal(length(grep("2.0", predictions)), 50) # spark.randomForest classification can work on libsvm data - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.randomForest(data, label ~ features, "classification") @@ -224,8 +222,6 @@ test_that("spark.randomForest", { }) test_that("spark.decisionTree", { - skip_on_cran() - # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16) @@ -242,7 +238,7 @@ test_that("spark.decisionTree", { expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -273,7 +269,7 @@ test_that("spark.decisionTree", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -309,7 +305,7 @@ test_that("spark.decisionTree", { expect_equal(length(grep("2.0", predictions)), 50) # spark.decisionTree classification can work on libsvm data - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.decisionTree(data, label ~ features, "classification") diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/tests/fulltests/test_parallelize_collect.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_parallelize_collect.R rename to R/pkg/tests/fulltests/test_parallelize_collect.R index 52d4c93ed9599..3d122ccaf448f 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ 
b/R/pkg/tests/fulltests/test_parallelize_collect.R @@ -39,8 +39,6 @@ jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", # Tests test_that("parallelize() on simple vectors and lists returns an RDD", { - skip_on_cran() - numVectorRDD <- parallelize(jsc, numVector, 1) numVectorRDD2 <- parallelize(jsc, numVector, 10) numListRDD <- parallelize(jsc, numList, 1) @@ -68,8 +66,6 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { }) test_that("collect(), following a parallelize(), gives back the original collections", { - skip_on_cran() - numVectorRDD <- parallelize(jsc, numVector, 10) expect_equal(collectRDD(numVectorRDD), as.list(numVector)) @@ -90,8 +86,6 @@ test_that("collect(), following a parallelize(), gives back the original collect }) test_that("regression: collect() following a parallelize() does not drop elements", { - skip_on_cran() - # 10 %/% 6 = 1, ceiling(10 / 6) = 2 collLen <- 10 numPart <- 6 @@ -101,8 +95,6 @@ test_that("regression: collect() following a parallelize() does not drop element }) test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { - skip_on_cran() - # use the pairwise logical to indicate pairwise data numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_rdd.R rename to R/pkg/tests/fulltests/test_rdd.R index fb244e1d49e20..6ee1fceffd822 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/tests/fulltests/test_rdd.R @@ -29,30 +29,22 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - skip_on_cran() - expect_equal(getNumPartitionsRDD(rdd), 2) expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { - skip_on_cran() - expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - skip_on_cran() - expect_equal(countRDD(rdd), 10) expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { - skip_on_cran() - mods <- lapply(rdd, function(x) { x %% 3 }) actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) @@ -64,40 +56,30 @@ test_that("count by values and keys", { }) test_that("lapply on RDD", { - skip_on_cran() - multiples <- lapply(rdd, function(x) { 2 * x }) actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { - skip_on_cran() - sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { - skip_on_cran() - sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { - skip_on_cran() - flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { - skip_on_cran() - filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) @@ -113,8 +95,6 @@ test_that("filterRDD on RDD", { }) test_that("lookup on RDD", { - skip_on_cran() - vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) @@ -123,8 
+103,6 @@ test_that("lookup on RDD", { }) test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { - skip_on_cran() - rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( @@ -139,8 +117,6 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { - skip_on_cran() - # RDD rdd2 <- rdd # PipelinedRDD @@ -182,8 +158,6 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp }) test_that("reduce on RDD", { - skip_on_cran() - sum <- reduce(rdd, "+") expect_equal(sum, 55) @@ -193,8 +167,6 @@ test_that("reduce on RDD", { }) test_that("lapply with dependency", { - skip_on_cran() - fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) actual <- collectRDD(multiples) @@ -203,8 +175,6 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { - skip_on_cran() - func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) @@ -221,14 +191,10 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { - skip_on_cran() - expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { - skip_on_cran() - # ported from RDDSuite.scala, modified seeds data <- parallelize(sc, 1:100, 2L) for (seed in 4:5) { @@ -271,8 +237,6 @@ test_that("takeSample() on RDDs", { }) test_that("mapValues() on pairwise RDDs", { - skip_on_cran() - multiples <- mapValues(intRdd, function(x) { x * 2 }) actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { @@ -282,8 +246,6 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { - skip_on_cran() - l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -296,8 +258,6 @@ test_that("flatMapValues() on pairwise RDDs", { }) test_that("reduceByKeyLocally() on PairwiseRDDs", { - skip_on_cran() - pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) actual <- reduceByKeyLocally(pairs, "+") expect_equal(sortKeyValueList(actual), @@ -311,8 +271,6 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { }) test_that("distinct() on RDDs", { - skip_on_cran() - nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) uniques <- distinctRDD(rdd.rep2) @@ -321,29 +279,21 @@ test_that("distinct() on RDDs", { }) test_that("maximum() on RDDs", { - skip_on_cran() - max <- maximum(rdd) expect_equal(max, 10) }) test_that("minimum() on RDDs", { - skip_on_cran() - min <- minimum(rdd) expect_equal(min, 1) }) test_that("sumRDD() on RDDs", { - skip_on_cran() - sum <- sumRDD(rdd) expect_equal(sum, 55) }) test_that("keyBy on RDDs", { - skip_on_cran() - func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collectRDD(keys) @@ -351,8 +301,6 @@ test_that("keyBy on RDDs", { }) test_that("repartition/coalesce on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition @@ -374,8 +322,6 @@ test_that("repartition/coalesce on RDDs", { }) test_that("sortBy() on RDDs", { - skip_on_cran() - sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) actual <- collectRDD(sortedRdd) expect_equal(actual, 
as.list(sort(nums, decreasing = TRUE))) @@ -387,8 +333,6 @@ test_that("sortBy() on RDDs", { }) test_that("takeOrdered() on RDDs", { - skip_on_cran() - l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- takeOrdered(rdd, 6L) @@ -401,8 +345,6 @@ test_that("takeOrdered() on RDDs", { }) test_that("top() on RDDs", { - skip_on_cran() - l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- top(rdd, 6L) @@ -415,8 +357,6 @@ test_that("top() on RDDs", { }) test_that("fold() on RDDs", { - skip_on_cran() - actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) @@ -426,8 +366,6 @@ test_that("fold() on RDDs", { }) test_that("aggregateRDD() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, list(1, 2, 3, 4)) zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } @@ -441,8 +379,6 @@ test_that("aggregateRDD() on RDDs", { }) test_that("zipWithUniqueId() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 4), @@ -457,8 +393,6 @@ test_that("zipWithUniqueId() on RDDs", { }) test_that("zipWithIndex() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), @@ -473,32 +407,24 @@ test_that("zipWithIndex() on RDDs", { }) test_that("glom() on RDD", { - skip_on_cran() - rdd <- parallelize(sc, as.list(1:4), 2L) actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { - skip_on_cran() - keys <- keys(intRdd) actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { - skip_on_cran() - values <- values(intRdd) actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { - skip_on_cran() - actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) @@ -516,8 +442,6 @@ test_that("pipeRDD() on RDDs", { }) test_that("zipRDD() on RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) actual <- collectRDD(zipRDD(rdd1, rdd2)) @@ -547,8 +471,6 @@ test_that("zipRDD() on RDDs", { }) test_that("cartesian() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, 1:3) actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), @@ -592,8 +514,6 @@ test_that("cartesian() on RDDs", { }) test_that("subtract() on RDDs", { - skip_on_cran() - l <- list(1, 1, 2, 2, 3, 4) rdd1 <- parallelize(sc, l) @@ -621,8 +541,6 @@ test_that("subtract() on RDDs", { }) test_that("subtractByKey() on pairwise RDDs", { - skip_on_cran() - l <- list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)) rdd1 <- parallelize(sc, l) @@ -652,8 +570,6 @@ test_that("subtractByKey() on pairwise RDDs", { }) test_that("intersection() on RDDs", { - skip_on_cran() - # intersection with self actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) @@ -670,8 +586,6 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) @@ -696,8 +610,6 @@ test_that("join() on pairwise RDDs", { }) 
test_that("leftOuterJoin() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) @@ -728,8 +640,6 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) @@ -757,8 +667,6 @@ test_that("rightOuterJoin() on pairwise RDDs", { }) test_that("fullOuterJoin() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) @@ -790,8 +698,6 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { - skip_on_cran() - numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) @@ -841,8 +747,6 @@ test_that("sortByKey() on pairwise RDDs", { }) test_that("collectAsMap() on a pairwise RDD", { - skip_on_cran() - rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = 2, `3` = 4)) @@ -861,15 +765,11 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { - skip_on_cran() - rdd <- parallelize(sc, list(1:10)) expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { - skip_on_cran() - rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) fractions <- list(a = 0.2, b = 0.1) @@ -894,8 +794,6 @@ test_that("sampleByKey() on pairwise RDDs", { }) test_that("Test correct concurrency of RRDD.compute()", { - skip_on_cran() - rdd <- parallelize(sc, 1:1000, 100) jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") zrdd <- callJMethod(jrdd, "zip", jrdd) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/tests/fulltests/test_shuffle.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_shuffle.R rename to R/pkg/tests/fulltests/test_shuffle.R index 18320ea44b389..98300c67c415f 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/tests/fulltests/test_shuffle.R @@ -37,8 +37,6 @@ strList <- list("Dexter Morgan: Blood. 
Sometimes it sets my teeth on edge and ", strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { - skip_on_cran() - grouped <- groupByKey(intRdd, 2L) actual <- collectRDD(grouped) @@ -48,8 +46,6 @@ test_that("groupByKey for integers", { }) test_that("groupByKey for doubles", { - skip_on_cran() - grouped <- groupByKey(doubleRdd, 2L) actual <- collectRDD(grouped) @@ -59,8 +55,6 @@ test_that("groupByKey for doubles", { }) test_that("reduceByKey for ints", { - skip_on_cran() - reduced <- reduceByKey(intRdd, "+", 2L) actual <- collectRDD(reduced) @@ -70,8 +64,6 @@ test_that("reduceByKey for ints", { }) test_that("reduceByKey for doubles", { - skip_on_cran() - reduced <- reduceByKey(doubleRdd, "+", 2L) actual <- collectRDD(reduced) @@ -80,8 +72,6 @@ test_that("reduceByKey for doubles", { }) test_that("combineByKey for ints", { - skip_on_cran() - reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -91,8 +81,6 @@ test_that("combineByKey for ints", { }) test_that("combineByKey for doubles", { - skip_on_cran() - reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -101,8 +89,6 @@ test_that("combineByKey for doubles", { }) test_that("combineByKey for characters", { - skip_on_cran() - stringKeyRDD <- parallelize(sc, list(list("max", 1L), list("min", 2L), list("other", 3L), list("max", 4L)), 2L) @@ -115,8 +101,6 @@ test_that("combineByKey for characters", { }) test_that("aggregateByKey", { - skip_on_cran() - # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -145,8 +129,6 @@ test_that("aggregateByKey", { }) test_that("foldByKey", { - skip_on_cran() - # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) @@ -190,8 +172,6 @@ test_that("foldByKey", { }) test_that("partitionBy() partitions data correctly", { - skip_on_cran() - # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } @@ -207,8 +187,6 @@ test_that("partitionBy() partitions data correctly", { }) test_that("partitionBy works with dependencies", { - skip_on_cran() - kOne <- 1 partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } @@ -227,8 +205,6 @@ test_that("partitionBy works with dependencies", { }) test_that("test partitionBy with string keys", { - skip_on_cran() - words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/tests/fulltests/test_sparkR.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_sparkR.R rename to R/pkg/tests/fulltests/test_sparkR.R index a40981c188f7a..f73fc6baeccef 100644 --- a/R/pkg/inst/tests/testthat/test_sparkR.R +++ b/R/pkg/tests/fulltests/test_sparkR.R @@ -18,8 +18,6 @@ context("functions in sparkR.R") test_that("sparkCheckInstall", { - skip_on_cran() - # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, # and the SparkR job was submitted by "spark-submit" sparkHome <- paste0(tempdir(), "/", "sparkHome") diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R similarity index 99% rename from R/pkg/inst/tests/testthat/test_sparkSQL.R rename to R/pkg/tests/fulltests/test_sparkSQL.R index c790d02b107be..af529067f43e0 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -61,7 +61,7 @@ unsetHiveContext <- function() 
{ # Tests for SparkSQL functions in SparkR filesBefore <- list.files(path = sparkRDir, all.files = TRUE) -sparkSession <- if (not_cran_or_windows_with_hadoop()) { +sparkSession <- if (windows_with_hadoop()) { sparkR.session(master = sparkRTestMaster) } else { sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) @@ -100,26 +100,20 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) -if (.Platform$OS.type == "windows") { +if (is_windows()) { Sys.setenv(TZ = "GMT") } test_that("calling sparkRSQL.init returns existing SQL context", { - skip_on_cran() - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) }) test_that("calling sparkRSQL.init returns existing SparkSession", { - skip_on_cran() - expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) test_that("calling sparkR.session returns existing SparkSession", { - skip_on_cran() - expect_equal(sparkR.session(), sparkSession) }) @@ -217,8 +211,6 @@ test_that("structField type strings", { }) test_that("create DataFrame from RDD", { - skip_on_cran() - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) @@ -316,8 +308,6 @@ test_that("create DataFrame from RDD", { }) test_that("createDataFrame uses files for large objects", { - skip_on_cran() - # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") @@ -330,7 +320,7 @@ test_that("createDataFrame uses files for large objects", { }) test_that("read/write csv as DataFrame", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -380,8 +370,6 @@ test_that("read/write csv as DataFrame", { }) test_that("Support other types for options", { - skip_on_cran() - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -436,8 +424,6 @@ test_that("convert NAs to null type in DataFrames", { }) test_that("toDF", { - skip_on_cran() - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") @@ -549,8 +535,6 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { - skip_on_cran() - ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) @@ -563,8 +547,6 @@ test_that("create DataFrame from a data.frame with complex types", { }) test_that("Collect DataFrame with complex types", { - skip_on_cran() - # ArrayType df <- read.json(complexTypeJsonPath) ldf <- collect(df) @@ -607,7 +589,7 @@ test_that("Collect DataFrame with complex types", { }) test_that("read/write json files", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { # Test read.df df <- read.df(jsonPath, "json") expect_is(df, "SparkDataFrame") @@ -654,8 +636,6 @@ test_that("read/write json files", { }) 
test_that("read/write json files - compression option", { - skip_on_cran() - df <- read.df(jsonPath, "json") jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") @@ -669,8 +649,6 @@ test_that("read/write json files - compression option", { }) test_that("jsonRDD() on a RDD with json string", { - skip_on_cran() - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(countRDD(rdd), 3) @@ -730,8 +708,6 @@ test_that( }) test_that("test cache, uncache and clearCache", { - skip_on_cran() - df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") cacheTable("table1") @@ -744,7 +720,7 @@ test_that("test cache, uncache and clearCache", { }) test_that("insertInto() on a registered table", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { df <- read.df(jsonPath, "json") write.df(df, parquetPath, "parquet", "overwrite") dfParquet <- read.df(parquetPath, "parquet") @@ -787,8 +763,6 @@ test_that("tableToDF() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { - skip_on_cran() - df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") @@ -796,8 +770,6 @@ test_that("toRDD() returns an RRDD", { }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - skip_on_cran() - df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) @@ -808,8 +780,6 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { }) test_that("union on mixed serialization types correctly returns a byte RRDD", { - skip_on_cran() - # Byte RDD nums <- 1:10 rdd <- parallelize(sc, nums, 2L) @@ -839,8 +809,6 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) test_that("objectFile() works with row serialization", { - skip_on_cran() - objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) @@ -853,8 +821,6 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - skip_on_cran() - df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 @@ -923,8 +889,6 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { - skip_on_cran() - df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -964,7 +928,7 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", }) test_that("setCheckpointDir(), checkpoint() on a DataFrame", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { checkpointDir <- file.path(tempdir(), "cproot") expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) @@ -1341,7 +1305,7 @@ test_that("column calculation", { }) test_that("test HiveContext", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { setHiveContext(sc) schema <- structType(structField("name", "string"), structField("age", "integer"), @@ -1395,8 +1359,6 @@ test_that("column operators", { }) test_that("column functions", { - skip_on_cran() - c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) @@ -1782,8 +1744,6 @@ test_that("when(), otherwise() and ifelse() with column on a DataFrame", { }) test_that("group by, agg functions", { - skip_on_cran() - df <- read.json(jsonPath) df1 <- 
agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) @@ -2125,8 +2085,6 @@ test_that("filter() on a DataFrame", { }) test_that("join(), crossJoin() and merge() on a DataFrame", { - skip_on_cran() - df <- read.json(jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", @@ -2400,8 +2358,6 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write ORC files", { - skip_on_cran() - setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2423,8 +2379,6 @@ test_that("read/write ORC files", { }) test_that("read/write ORC files - compression option", { - skip_on_cran() - setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2440,7 +2394,7 @@ test_that("read/write ORC files - compression option", { }) test_that("read/write Parquet files", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { df <- read.df(jsonPath, "json") # Test write.df and read.df write.df(df, parquetPath, "parquet", mode = "overwrite") @@ -2473,8 +2427,6 @@ test_that("read/write Parquet files", { }) test_that("read/write Parquet files - compression option/mode", { - skip_on_cran() - df <- read.df(jsonPath, "json") tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") @@ -2492,8 +2444,6 @@ test_that("read/write Parquet files - compression option/mode", { }) test_that("read/write text files", { - skip_on_cran() - # Test write.df and read.df df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") @@ -2515,8 +2465,6 @@ test_that("read/write text files", { }) test_that("read/write text files - compression option", { - skip_on_cran() - df <- read.df(jsonPath, "text") textPath <- tempfile(pattern = "textPath", fileext = ".txt") @@ -2750,8 +2698,6 @@ test_that("approxQuantile() on a DataFrame", { }) test_that("SQL error message is returned from JVM", { - skip_on_cran() - retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) @@ -2760,8 +2706,6 @@ test_that("SQL error message is returned from JVM", { irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { - skip_on_cran() - expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -2984,8 +2928,6 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { }) test_that("dapplyCollect() on DataFrame with a binary column", { - skip_on_cran() - df <- data.frame(key = 1:3) df$bytes <- lapply(df$key, serialize, connection = NULL) @@ -3006,8 +2948,6 @@ test_that("dapplyCollect() on DataFrame with a binary column", { }) test_that("repartition by columns on DataFrame", { - skip_on_cran() - df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -3046,8 +2986,6 @@ test_that("repartition by columns on DataFrame", { }) test_that("coalesce, repartition, numPartitions", { - skip_on_cran() - df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) expect_equal(getNumPartitions(coalesce(df, 3)), 3) @@ -3067,8 +3005,6 @@ test_that("coalesce, repartition, numPartitions", { }) test_that("gapply() and gapplyCollect() on a DataFrame", { - skip_on_cran() - df <- createDataFrame ( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -3186,8 +3122,6 @@ test_that("Window functions on a DataFrame", { }) 
test_that("createDataFrame sqlContext parameter backward compatibility", { - skip_on_cran() - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") @@ -3221,8 +3155,6 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { }) test_that("randomSplit", { - skip_on_cran() - num <- 4000 df <- createDataFrame(data.frame(id = 1:num)) weights <- c(2, 3, 5) @@ -3269,8 +3201,6 @@ test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiW }) test_that("enableHiveSupport on SparkSession", { - skip_on_cran() - setHiveContext(sc) unsetHiveContext() # if we are still here, it must be built with hive @@ -3286,8 +3216,6 @@ test_that("Spark version from SparkSession", { }) test_that("Call DataFrameWriter.save() API in Java without path and check argument types", { - skip_on_cran() - df <- read.df(jsonPath, "json") # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in write.df API and then it calls @@ -3314,8 +3242,6 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { - skip_on_cran() - # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. @@ -3440,8 +3366,6 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { - skip_on_cran() # skip because when run from R CMD check SPARK_HOME is not the current directory - # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. 
filesAfter <- list.files(path = sparkRDir, all.files = TRUE) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_streaming.R rename to R/pkg/tests/fulltests/test_streaming.R index b20b4312fbaae..d691de7cd725d 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -24,7 +24,7 @@ context("Structured Streaming") sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsonSubDir <- file.path("sparkr-test", "json", "") -if (.Platform$OS.type == "windows") { +if (is_windows()) { # file.path removes the empty separator on Windows, adds it back jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep) } @@ -47,8 +47,6 @@ schema <- structType(structField("name", "string"), structField("count", "double")) test_that("read.stream, write.stream, awaitTermination, stopQuery", { - skip_on_cran() - df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -69,8 +67,6 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { }) test_that("print from explain, lastProgress, status, isActive", { - skip_on_cran() - df <- read.stream("json", path = jsonDir, schema = schema) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -90,8 +86,6 @@ test_that("print from explain, lastProgress, status, isActive", { }) test_that("Stream other format", { - skip_on_cran() - parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") df <- read.df(jsonPath, "json", schema) write.df(df, parquetPath, "parquet", "overwrite") @@ -118,8 +112,6 @@ test_that("Stream other format", { }) test_that("Non-streaming DataFrame", { - skip_on_cran() - c <- as.DataFrame(cars) expect_false(isStreaming(c)) @@ -129,8 +121,6 @@ test_that("Non-streaming DataFrame", { }) test_that("Unsupported operation", { - skip_on_cran() - # memory sink without aggregation df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), @@ -139,8 +129,6 @@ test_that("Unsupported operation", { }) test_that("Terminated by error", { - skip_on_cran() - df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) counts <- count(group_by(df, "name")) # This would not fail before returning with a StreamingQuery, diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/tests/fulltests/test_take.R similarity index 99% rename from R/pkg/inst/tests/testthat/test_take.R rename to R/pkg/tests/fulltests/test_take.R index c00723ba31f4c..8936cc57da227 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/tests/fulltests/test_take.R @@ -34,8 +34,6 @@ sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FA sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { - skip_on_cran() - numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/tests/fulltests/test_textFile.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_textFile.R rename to 
R/pkg/tests/fulltests/test_textFile.R index e8a961cb3e870..be2d2711ff88e 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/tests/fulltests/test_textFile.R @@ -24,8 +24,6 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -38,8 +36,6 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -50,8 +46,6 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -70,8 +64,6 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -86,8 +78,6 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -102,8 +92,6 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) @@ -115,8 +103,6 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -142,8 +128,6 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) @@ -157,8 +141,6 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/tests/fulltests/test_utils.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_utils.R rename to R/pkg/tests/fulltests/test_utils.R index 6197ae7569879..af81423aa8dd0 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -23,7 +23,6 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { - skip_on_cran() # It's hard to manually create a Java List using rJava, since it does not # support generics well. 
Instead, we rely on collectRDD() returning a # JList. @@ -41,7 +40,6 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists }) test_that("serializeToBytes on RDD", { - skip_on_cran() # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -169,8 +167,6 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { - skip_on_cran() - method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, "col", "unknown", TRUE), @@ -181,8 +177,6 @@ test_that("captureJVMException", { }) test_that("hashCode", { - skip_on_cran() - expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) @@ -243,6 +237,3 @@ test_that("basenameSansExtFromUrl", { }) sparkR.session.stop() - -message("--- End test (utils) ", as.POSIXct(Sys.time(), tz = "GMT")) -message("elapsed ", (proc.time() - timer_ptm)[3]) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index f0bef4f6d2662..f00a610679752 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -24,8 +24,6 @@ options("warn" = 2) if (.Platform$OS.type == "windows") { Sys.setenv(TZ = "GMT") } -message("--- Start test ", as.POSIXct(Sys.time(), tz = "GMT")) -timer_ptm <- proc.time() # Setup global test environment # Install Spark first to set SPARK_HOME @@ -43,3 +41,11 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { } test_package("SparkR") + +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + # for testthat 1.0.2 later, change reporter from "summary" to default_reporter() + testthat:::run_tests("SparkR", + file.path(sparkRDir, "pkg", "tests", "fulltests"), + NULL, + "summary") +} diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml new file mode 100644 index 0000000000000..d00cf2788b964 --- /dev/null +++ b/common/kvstore/pom.xml @@ -0,0 +1,101 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../../pom.xml + + + spark-kvstore_2.11 + jar + Spark Project Local DB + http://spark.apache.org/ + + kvstore + + + + + com.google.guava + guava + + + org.fusesource.leveldbjni + leveldbjni-all + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + + commons-io + commons-io + test + + + log4j + log4j + test + + + org.slf4j + slf4j-api + test + + + org.slf4j + slf4j-log4j12 + test + + + io.dropwizard.metrics + metrics-core + test + + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java new file mode 100644 index 0000000000000..8b8899023c938 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Tags a field to be indexed when storing an object. + * + *

+ * Types are required to have a natural index that uniquely identifies instances in the store. + * The default value of the annotation identifies the natural index for the type. + *

+ * + *

+ * Indexes allow for more efficient sorting of data read from the store. By annotating a field or + * "getter" method with this annotation, an index will be created that will provide sorting based on + * the string value of that field. + *

+ * + *

+ * Note that creating indices means more space will be needed, and maintenance operations like + * updating or deleting a value will become more expensive. + *

+ * + *

+ * Indices are restricted to String, integral types (byte, short, int, long), boolean, and + * arrays of those values. +

+ */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.FIELD, ElementType.METHOD}) +public @interface KVIndex { + + public static final String NATURAL_INDEX_NAME = "__main__"; + + /** + * The name of the index to be created for the annotated entity. Must be unique within + * the class. Index names are not allowed to start with an underscore (that's reserved for + * internal use). The default value is the natural index name (which is always a copy index + * regardless of the annotation's values). + */ + String value() default NATURAL_INDEX_NAME; + + /** + * The name of the parent index of this index. By default there is no parent index, so the + * generated data can be retrieved without having to provide a parent value. + * + *

+ * If a parent index is defined, iterating over the data using the index will require providing + * a single value for the parent index. This serves as a rudimentary way to provide relationships + * between entities in the store. + *

+ */ + String parent() default ""; + + /** + * Whether to copy the instance's data to the index, instead of just storing a pointer to the + * data. The default behavior is to just store a reference; that saves disk space but is slower + * to read, since there's a level of indirection. + */ + boolean copy() default false; + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java new file mode 100644 index 0000000000000..3be4b829b4d8d --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.Closeable; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; + +/** + * Abstraction for a local key/value store for storing app data. + * + *

+ * There are two main features provided by the implementations of this interface: + *

+ *
+ * <h3>Serialization</h3>
+ *

+ * If the underlying data store requires serialization, data will be serialized and deserialized + * using a {@link KVStoreSerializer}, which can be customized by the application. The serializer is + * based on Jackson, so it supports all the Jackson annotations for controlling the serialization of + * app-defined types. +

+ * + *

+ * Data is also automatically compressed to save disk space. + *

+ *
+ * <h3>Automatic Key Management</h3>
+ *

+ * When using the built-in key management, the implementation will automatically create unique + * keys for each type written to the store. Keys are based on the type name, and always start + * with the "+" prefix character (so that it's easy to use both manual and automatic key + * management APIs without conflicts). + *

+ * + *

+ * Another feature of automatic key management is indexing; by annotating fields or methods of + * objects written to the store with {@link KVIndex}, indices are created to sort the data + * by the values of those properties. This makes it possible to provide sorting without having + * to load all instances of those types from the store. + *

+ * + *

+ * KVStore instances are thread-safe for both reads and writes. + *

+ */ +public interface KVStore extends Closeable { + + /** + * Returns app-specific metadata from the store, or null if it's not currently set. + * + *

+ * The metadata type is application-specific. This is a convenience method so that applications + * don't need to define their own keys for this information. + *

+ */ + <T> T getMetadata(Class<T> klass) throws Exception; + + /** + * Writes the given value in the store metadata key. + */ + void setMetadata(Object value) throws Exception; + + /** + * Read a specific instance of an object. + * + * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys + * are not allowed. + * @throws NoSuchElementException If an element with the given key does not exist. + */ + <T> T read(Class<T> klass, Object naturalKey) throws Exception; + + /** + * Writes the given object to the store, including indexed fields. Indices are updated based + * on the annotated fields of the object's class. + * + *

+ * Writes may be slower when the object already exists in the store, since it will involve + * updating existing indices. + *

+ * + * @param value The object to write. + */ + void write(Object value) throws Exception; + + /** + * Removes an object and all data related to it, like index entries, from the store. + * + * @param type The object's type. + * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys + * are not allowed. + * @throws NoSuchElementException If an element with the given key does not exist. + */ + void delete(Class type, Object naturalKey) throws Exception; + + /** + * Returns a configurable view for iterating over entities of the given type. + */ + KVStoreView view(Class type) throws Exception; + + /** + * Returns the number of items of the given type currently in the store. + */ + long count(Class type) throws Exception; + + /** + * Returns the number of items of the given type which match the given indexed value. + */ + long count(Class type, String index, Object indexedValue) throws Exception; + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java new file mode 100644 index 0000000000000..3efdec9ed32be --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.util.Iterator; +import java.util.List; + +/** + * An iterator for KVStore. + * + *

+ * Iterators may keep references to resources that need to be closed. It's recommended that users + * explicitly close iterators after they're used. + *

+ */ +public interface KVStoreIterator extends Iterator, AutoCloseable { + + /** + * Retrieve multiple elements from the store. + * + * @param max Maximum number of elements to retrieve. + */ + List next(int max); + + /** + * Skip in the iterator. + * + * @return Whether there are items left after skipping. + */ + boolean skip(long n); + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java new file mode 100644 index 0000000000000..b84ec91cf67a0 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * Serializer used to translate between app-defined types and the LevelDB store. + * + *

+ * The serializer is based on Jackson, so values are written as JSON. It also allows "naked strings" + * and integers to be stored directly as values, in which case they are written as UTF-8 strings. +

+ */ +public class KVStoreSerializer { + + /** + * Object mapper used to process app-specific types. If an application requires a specific + * configuration of the mapper, it can subclass this serializer and add custom configuration + * to this object. + */ + protected final ObjectMapper mapper; + + public KVStoreSerializer() { + this.mapper = new ObjectMapper(); + } + + public final byte[] serialize(Object o) throws Exception { + if (o instanceof String) { + return ((String) o).getBytes(UTF_8); + } else { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + GZIPOutputStream out = new GZIPOutputStream(bytes); + try { + mapper.writeValue(out, o); + } finally { + out.close(); + } + return bytes.toByteArray(); + } + } + + @SuppressWarnings("unchecked") + public final T deserialize(byte[] data, Class klass) throws Exception { + if (klass.equals(String.class)) { + return (T) new String(data, UTF_8); + } else { + GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data)); + try { + return mapper.readValue(in, klass); + } finally { + in.close(); + } + } + } + + final byte[] serialize(long value) { + return String.valueOf(value).getBytes(UTF_8); + } + + final long deserializeLong(byte[] data) { + return Long.parseLong(new String(data, UTF_8)); + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java new file mode 100644 index 0000000000000..b761640e6da8b --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.util.Iterator; +import java.util.Map; + +import com.google.common.base.Preconditions; + +/** + * A configurable view that allows iterating over values in a {@link KVStore}. + * + *

+ * The different methods can be used to configure the behavior of the iterator. Calling the same + * method multiple times is allowed; the most recent value will be used. + *

+ * + *

+ * The iterators returned by this view are of type {@link KVStoreIterator}; they auto-close + * when used in a for loop that exhausts their contents, but when used manually, they need + * to be closed explicitly unless all elements are read. + *

+ */ +public abstract class KVStoreView implements Iterable { + + final Class type; + + boolean ascending = true; + String index = KVIndex.NATURAL_INDEX_NAME; + Object first = null; + Object last = null; + Object parent = null; + long skip = 0L; + long max = Long.MAX_VALUE; + + public KVStoreView(Class type) { + this.type = type; + } + + /** + * Reverses the order of iteration. By default, iterates in ascending order. + */ + public KVStoreView reverse() { + ascending = !ascending; + return this; + } + + /** + * Iterates according to the given index. + */ + public KVStoreView index(String name) { + this.index = Preconditions.checkNotNull(name); + return this; + } + + /** + * Defines the value of the parent index when iterating over a child index. Only elements that + * match the parent index's value will be included in the iteration. + * + *

+ * Required when iterating over child indices; setting a parent value will generate an error + * when iterating over a parent-less index. +

+ */ + public KVStoreView parent(Object value) { + this.parent = value; + return this; + } + + /** + * Iterates starting at the given value of the chosen index (inclusive). + */ + public KVStoreView first(Object value) { + this.first = value; + return this; + } + + /** + * Stops iteration at the given value of the chosen index (inclusive). + */ + public KVStoreView last(Object value) { + this.last = value; + return this; + } + + /** + * Stops iteration after a number of elements has been retrieved. + */ + public KVStoreView max(long max) { + Preconditions.checkArgument(max > 0L, "max must be positive."); + this.max = max; + return this; + } + + /** + * Skips a number of elements at the start of iteration. Skipped elements are not accounted + * when using {@link #max(long)}. + */ + public KVStoreView skip(long n) { + this.skip = n; + return this; + } + + /** + * Returns an iterator for the current configuration. + */ + public KVStoreIterator closeableIterator() throws Exception { + return (KVStoreIterator) iterator(); + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java new file mode 100644 index 0000000000000..90f2ff0079b8a --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; + +import com.google.common.base.Preconditions; + +/** + * Wrapper around types managed in a KVStore, providing easy access to their indexed fields. 
+ */ +public class KVTypeInfo { + + private final Class type; + private final Map indices; + private final Map accessors; + + public KVTypeInfo(Class type) throws Exception { + this.type = type; + this.accessors = new HashMap<>(); + this.indices = new HashMap<>(); + + for (Field f : type.getDeclaredFields()) { + KVIndex idx = f.getAnnotation(KVIndex.class); + if (idx != null) { + checkIndex(idx, indices); + indices.put(idx.value(), idx); + f.setAccessible(true); + accessors.put(idx.value(), new FieldAccessor(f)); + } + } + + for (Method m : type.getDeclaredMethods()) { + KVIndex idx = m.getAnnotation(KVIndex.class); + if (idx != null) { + checkIndex(idx, indices); + Preconditions.checkArgument(m.getParameterTypes().length == 0, + "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); + indices.put(idx.value(), idx); + m.setAccessible(true); + accessors.put(idx.value(), new MethodAccessor(m)); + } + } + + Preconditions.checkArgument(indices.containsKey(KVIndex.NATURAL_INDEX_NAME), + "No natural index defined for type %s.", type.getName()); + Preconditions.checkArgument(indices.get(KVIndex.NATURAL_INDEX_NAME).parent().isEmpty(), + "Natural index of %s cannot have a parent.", type.getName()); + + for (KVIndex idx : indices.values()) { + if (!idx.parent().isEmpty()) { + KVIndex parent = indices.get(idx.parent()); + Preconditions.checkArgument(parent != null, + "Cannot find parent %s of index %s.", idx.parent(), idx.value()); + Preconditions.checkArgument(parent.parent().isEmpty(), + "Parent index %s of index %s cannot be itself a child index.", idx.parent(), idx.value()); + } + } + } + + private void checkIndex(KVIndex idx, Map indices) { + Preconditions.checkArgument(idx.value() != null && !idx.value().isEmpty(), + "No name provided for index in type %s.", type.getName()); + Preconditions.checkArgument( + !idx.value().startsWith("_") || idx.value().equals(KVIndex.NATURAL_INDEX_NAME), + "Index name %s (in type %s) is not allowed.", idx.value(), type.getName()); + Preconditions.checkArgument(idx.parent().isEmpty() || !idx.parent().equals(idx.value()), + "Index %s cannot be parent of itself.", idx.value()); + Preconditions.checkArgument(!indices.containsKey(idx.value()), + "Duplicate index %s for type %s.", idx.value(), type.getName()); + } + + public Class getType() { + return type; + } + + public Object getIndexValue(String indexName, Object instance) throws Exception { + return getAccessor(indexName).get(instance); + } + + public Stream indices() { + return indices.values().stream(); + } + + Accessor getAccessor(String indexName) { + Accessor a = accessors.get(indexName); + Preconditions.checkArgument(a != null, "No index %s.", indexName); + return a; + } + + Accessor getParentAccessor(String indexName) { + KVIndex index = indices.get(indexName); + return index.parent().isEmpty() ? null : getAccessor(index.parent()); + } + + /** + * Abstracts the difference between invoking a Field and a Method. 
+ */ + interface Accessor { + + Object get(Object instance) throws Exception; + + } + + private class FieldAccessor implements Accessor { + + private final Field field; + + FieldAccessor(Field field) { + this.field = field; + } + + @Override + public Object get(Object instance) throws Exception { + return field.get(instance); + } + + } + + private class MethodAccessor implements Accessor { + + private final Method method; + + MethodAccessor(Method method) { + this.method = method; + } + + @Override + public Object get(Object instance) throws Exception { + return method.invoke(instance); + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java new file mode 100644 index 0000000000000..08b22fd8265d8 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.fusesource.leveldbjni.JniDBFactory; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.Options; +import org.iq80.leveldb.WriteBatch; + +/** + * Implementation of KVStore that uses LevelDB as the underlying data store. + */ +public class LevelDB implements KVStore { + + @VisibleForTesting + static final long STORE_VERSION = 1L; + + @VisibleForTesting + static final byte[] STORE_VERSION_KEY = "__version__".getBytes(UTF_8); + + /** DB key where app metadata is stored. */ + private static final byte[] METADATA_KEY = "__meta__".getBytes(UTF_8); + + /** DB key where type aliases are stored. */ + private static final byte[] TYPE_ALIASES_KEY = "__types__".getBytes(UTF_8); + + final AtomicReference _db; + final KVStoreSerializer serializer; + + /** + * Keep a mapping of class names to a shorter, unique ID managed by the store. This serves two + * purposes: make the keys stored on disk shorter, and spread out the keys, since class names + * will often have a long, redundant prefix (think "org.apache.spark."). 
+ */ + private final ConcurrentMap typeAliases; + private final ConcurrentMap, LevelDBTypeInfo> types; + + public LevelDB(File path) throws Exception { + this(path, new KVStoreSerializer()); + } + + public LevelDB(File path, KVStoreSerializer serializer) throws Exception { + this.serializer = serializer; + this.types = new ConcurrentHashMap<>(); + + Options options = new Options(); + options.createIfMissing(!path.exists()); + this._db = new AtomicReference<>(JniDBFactory.factory.open(path, options)); + + byte[] versionData = db().get(STORE_VERSION_KEY); + if (versionData != null) { + long version = serializer.deserializeLong(versionData); + if (version != STORE_VERSION) { + throw new UnsupportedStoreVersionException(); + } + } else { + db().put(STORE_VERSION_KEY, serializer.serialize(STORE_VERSION)); + } + + Map aliases; + try { + aliases = get(TYPE_ALIASES_KEY, TypeAliases.class).aliases; + } catch (NoSuchElementException e) { + aliases = new HashMap<>(); + } + typeAliases = new ConcurrentHashMap<>(aliases); + } + + @Override + public T getMetadata(Class klass) throws Exception { + try { + return get(METADATA_KEY, klass); + } catch (NoSuchElementException nsee) { + return null; + } + } + + @Override + public void setMetadata(Object value) throws Exception { + if (value != null) { + put(METADATA_KEY, value); + } else { + db().delete(METADATA_KEY); + } + } + + T get(byte[] key, Class klass) throws Exception { + byte[] data = db().get(key); + if (data == null) { + throw new NoSuchElementException(new String(key, UTF_8)); + } + return serializer.deserialize(data, klass); + } + + private void put(byte[] key, Object value) throws Exception { + Preconditions.checkArgument(value != null, "Null values are not allowed."); + db().put(key, serializer.serialize(value)); + } + + @Override + public T read(Class klass, Object naturalKey) throws Exception { + Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); + byte[] key = getTypeInfo(klass).naturalIndex().start(null, naturalKey); + return get(key, klass); + } + + @Override + public void write(Object value) throws Exception { + Preconditions.checkArgument(value != null, "Null values are not allowed."); + LevelDBTypeInfo ti = getTypeInfo(value.getClass()); + + try (WriteBatch batch = db().createWriteBatch()) { + byte[] data = serializer.serialize(value); + synchronized (ti) { + Object existing; + try { + existing = get(ti.naturalIndex().entityKey(null, value), value.getClass()); + } catch (NoSuchElementException e) { + existing = null; + } + + PrefixCache cache = new PrefixCache(value); + byte[] naturalKey = ti.naturalIndex().toKey(ti.naturalIndex().getValue(value)); + for (LevelDBTypeInfo.Index idx : ti.indices()) { + byte[] prefix = cache.getPrefix(idx); + idx.add(batch, value, existing, data, naturalKey, prefix); + } + db().write(batch); + } + } + } + + @Override + public void delete(Class type, Object naturalKey) throws Exception { + Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); + try (WriteBatch batch = db().createWriteBatch()) { + LevelDBTypeInfo ti = getTypeInfo(type); + byte[] key = ti.naturalIndex().start(null, naturalKey); + synchronized (ti) { + byte[] data = db().get(key); + if (data != null) { + Object existing = serializer.deserialize(data, type); + PrefixCache cache = new PrefixCache(existing); + byte[] keyBytes = ti.naturalIndex().toKey(ti.naturalIndex().getValue(existing)); + for (LevelDBTypeInfo.Index idx : ti.indices()) { + idx.remove(batch, existing, keyBytes, 
cache.getPrefix(idx)); + } + db().write(batch); + } + } + } catch (NoSuchElementException nse) { + // Ignore. + } + } + + @Override + public KVStoreView view(Class type) throws Exception { + return new KVStoreView(type) { + @Override + public Iterator iterator() { + try { + return new LevelDBIterator<>(LevelDB.this, this); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + }; + } + + @Override + public long count(Class type) throws Exception { + LevelDBTypeInfo.Index idx = getTypeInfo(type).naturalIndex(); + return idx.getCount(idx.end(null)); + } + + @Override + public long count(Class type, String index, Object indexedValue) throws Exception { + LevelDBTypeInfo.Index idx = getTypeInfo(type).index(index); + return idx.getCount(idx.end(null, indexedValue)); + } + + @Override + public void close() throws IOException { + DB _db = this._db.getAndSet(null); + if (_db == null) { + return; + } + + try { + _db.close(); + } catch (IOException ioe) { + throw ioe; + } catch (Exception e) { + throw new IOException(e.getMessage(), e); + } + } + + /** Returns metadata about indices for the given type. */ + LevelDBTypeInfo getTypeInfo(Class type) throws Exception { + LevelDBTypeInfo ti = types.get(type); + if (ti == null) { + LevelDBTypeInfo tmp = new LevelDBTypeInfo(this, type, getTypeAlias(type)); + ti = types.putIfAbsent(type, tmp); + if (ti == null) { + ti = tmp; + } + } + return ti; + } + + /** + * Try to avoid use-after close since that has the tendency of crashing the JVM. This doesn't + * prevent methods that retrieved the instance from using it after close, but hopefully will + * catch most cases; otherwise, we'll need some kind of locking. + */ + DB db() { + DB _db = this._db.get(); + if (_db == null) { + throw new IllegalStateException("DB is closed."); + } + return _db; + } + + private byte[] getTypeAlias(Class klass) throws Exception { + byte[] alias = typeAliases.get(klass.getName()); + if (alias == null) { + synchronized (typeAliases) { + byte[] tmp = String.valueOf(typeAliases.size()).getBytes(UTF_8); + alias = typeAliases.putIfAbsent(klass.getName(), tmp); + if (alias == null) { + alias = tmp; + put(TYPE_ALIASES_KEY, new TypeAliases(typeAliases)); + } + } + } + return alias; + } + + /** Needs to be public for Jackson. */ + public static class TypeAliases { + + public Map aliases; + + TypeAliases(Map aliases) { + this.aliases = aliases; + } + + TypeAliases() { + this(null); + } + + } + + private static class PrefixCache { + + private final Object entity; + private final Map prefixes; + + PrefixCache(Object entity) { + this.entity = entity; + this.prefixes = new HashMap<>(); + } + + byte[] getPrefix(LevelDBTypeInfo.Index idx) throws Exception { + byte[] prefix = null; + if (idx.isChild()) { + prefix = prefixes.get(idx.parent()); + if (prefix == null) { + prefix = idx.parent().childPrefix(idx.parent().getValue(entity)); + prefixes.put(idx.parent(), prefix); + } + } + return prefix; + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java new file mode 100644 index 0000000000000..a5d0f9f4fb373 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.IOException; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.iq80.leveldb.DBIterator; + +class LevelDBIterator implements KVStoreIterator { + + private final LevelDB db; + private final boolean ascending; + private final DBIterator it; + private final Class type; + private final LevelDBTypeInfo ti; + private final LevelDBTypeInfo.Index index; + private final byte[] indexKeyPrefix; + private final byte[] end; + private final long max; + + private boolean checkedNext; + private byte[] next; + private boolean closed; + private long count; + + LevelDBIterator(LevelDB db, KVStoreView params) throws Exception { + this.db = db; + this.ascending = params.ascending; + this.it = db.db().iterator(); + this.type = params.type; + this.ti = db.getTypeInfo(type); + this.index = ti.index(params.index); + this.max = params.max; + + Preconditions.checkArgument(!index.isChild() || params.parent != null, + "Cannot iterate over child index %s without parent value.", params.index); + byte[] parent = index.isChild() ? index.parent().childPrefix(params.parent) : null; + + this.indexKeyPrefix = index.keyPrefix(parent); + + byte[] firstKey; + if (params.first != null) { + if (ascending) { + firstKey = index.start(parent, params.first); + } else { + firstKey = index.end(parent, params.first); + } + } else if (ascending) { + firstKey = index.keyPrefix(parent); + } else { + firstKey = index.end(parent); + } + it.seek(firstKey); + + byte[] end = null; + if (ascending) { + if (params.last != null) { + end = index.end(parent, params.last); + } else { + end = index.end(parent); + } + } else { + if (params.last != null) { + end = index.start(parent, params.last); + } + if (it.hasNext()) { + // When descending, the caller may have set up the start of iteration at a non-existant + // entry that is guaranteed to be after the desired entry. For example, if you have a + // compound key (a, b) where b is a, integer, you may seek to the end of the elements that + // have the same "a" value by specifying Integer.MAX_VALUE for "b", and that value may not + // exist in the database. So need to check here whether the next value actually belongs to + // the set being returned by the iterator before advancing. 
+ byte[] nextKey = it.peekNext().getKey(); + if (compare(nextKey, indexKeyPrefix) <= 0) { + it.next(); + } + } + } + this.end = end; + + if (params.skip > 0) { + skip(params.skip); + } + } + + @Override + public boolean hasNext() { + if (!checkedNext && !closed) { + next = loadNext(); + checkedNext = true; + } + if (!closed && next == null) { + try { + close(); + } catch (IOException ioe) { + throw Throwables.propagate(ioe); + } + } + return next != null; + } + + @Override + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + checkedNext = false; + + try { + T ret; + if (index == null || index.isCopy()) { + ret = db.serializer.deserialize(next, type); + } else { + byte[] key = ti.buildKey(false, ti.naturalIndex().keyPrefix(null), next); + ret = db.get(key, type); + } + next = null; + return ret; + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public List next(int max) { + List list = new ArrayList<>(max); + while (hasNext() && list.size() < max) { + list.add(next()); + } + return list; + } + + @Override + public boolean skip(long n) { + long skipped = 0; + while (skipped < n) { + if (next != null) { + checkedNext = false; + next = null; + skipped++; + continue; + } + + boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); + if (!hasNext) { + checkedNext = true; + return false; + } + + Map.Entry e = ascending ? it.next() : it.prev(); + if (!isEndMarker(e.getKey())) { + skipped++; + } + } + + return hasNext(); + } + + @Override + public synchronized void close() throws IOException { + if (!closed) { + it.close(); + closed = true; + } + } + + private byte[] loadNext() { + if (count >= max) { + return null; + } + + try { + while (true) { + boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); + if (!hasNext) { + return null; + } + + Map.Entry nextEntry; + try { + // Avoid races if another thread is updating the DB. + nextEntry = ascending ? it.next() : it.prev(); + } catch (NoSuchElementException e) { + return null; + } + + byte[] nextKey = nextEntry.getKey(); + // Next key is not part of the index, stop. + if (!startsWith(nextKey, indexKeyPrefix)) { + return null; + } + + // If the next key is an end marker, then skip it. + if (isEndMarker(nextKey)) { + continue; + } + + // If there's a known end key and iteration has gone past it, stop. + if (end != null) { + int comp = compare(nextKey, end) * (ascending ? 1 : -1); + if (comp > 0) { + return null; + } + } + + count++; + + // Next element is part of the iteration, return it. 
+ return nextEntry.getValue(); + } + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @VisibleForTesting + static boolean startsWith(byte[] key, byte[] prefix) { + if (key.length < prefix.length) { + return false; + } + + for (int i = 0; i < prefix.length; i++) { + if (key[i] != prefix[i]) { + return false; + } + } + + return true; + } + + private boolean isEndMarker(byte[] key) { + return (key.length > 2 && + key[key.length - 2] == LevelDBTypeInfo.KEY_SEPARATOR && + key[key.length - 1] == LevelDBTypeInfo.END_MARKER[0]); + } + + static int compare(byte[] a, byte[] b) { + int diff = 0; + int minLen = Math.min(a.length, b.length); + for (int i = 0; i < minLen; i++) { + diff += (a[i] - b[i]); + if (diff != 0) { + return diff; + } + } + + return a.length - b.length; + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java new file mode 100644 index 0000000000000..3ab17dbd03ca7 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -0,0 +1,516 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.lang.reflect.Array; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.iq80.leveldb.WriteBatch; + +/** + * Holds metadata about app-specific types stored in LevelDB. Serves as a cache for data collected + * via reflection, to make it cheaper to access it multiple times. + * + *

+ * The hierarchy of keys stored in LevelDB looks roughly like the following. This hierarchy ensures + * that iteration over indices is easy, and that updating values in the store is not overly + * expensive. Of note, indices use more disk space (one value per key) instead of keeping + * lists of pointers, which would be more expensive to update at runtime. +

+ * + *

+ * Indentation defines when a sub-key lives under a parent key. In LevelDB, this means the full + * key would be the concatenation of everything up to that point in the hierarchy, with each + * component separated by a NULL byte. + *

+ * + *
+ * +TYPE_NAME
+ *   NATURAL_INDEX
+ *     +NATURAL_KEY
+ *     -
+ *   -NATURAL_INDEX
+ *   INDEX_NAME
+ *     +INDEX_VALUE
+ *       +NATURAL_KEY
+ *     -INDEX_VALUE
+ *     .INDEX_VALUE
+ *       CHILD_INDEX_NAME
+ *         +CHILD_INDEX_VALUE
+ *           NATURAL_KEY_OR_DATA
+ *         -
+ *   -INDEX_NAME
+ * 
+ * + *

+ * Entity data (either the entity's natural key or a copy of the data) is stored in all keys + * that end with "+". A count of all objects that match a particular top-level index + * value is kept at the end marker ("-"). A count is also kept at the natural index's end + * marker, to make it easy to retrieve the number of all elements of a particular type. + *

+ * + *

+ * To illustrate, given a type "Foo", with a natural index and a second index called "bar", you'd + * have these keys and values in the store for two instances, one with natural key "key1" and the + * other "key2", both with value "yes" for "bar": + *

+ * + *
+ * Foo __main__ +key1   [data for instance 1]
+ * Foo __main__ +key2   [data for instance 2]
+ * Foo __main__ -       [count of all Foo]
+ * Foo bar +yes +key1   [instance 1 key or data, depending on index type]
+ * Foo bar +yes +key2   [instance 2 key or data, depending on index type]
+ * Foo bar +yes -       [count of all Foo with "bar=yes" ]
+ * 
+ * + *

+ * Note that all indexed values are prepended with "+", even if the index itself does not have an + * explicit end marker. This allows for easily skipping to the end of an index by telling LevelDB + * to seek to the "phantom" end marker of the index. Throughout the code and comments, this part + * of the full LevelDB key is generally referred to as the "index value" of the entity. + *

+ * + *

+ * Child indices are stored after their parent index. In the example above, let's assume there is + * a child index "child", whose parent is "bar". If both instances have value "no" for this field, + * the data in the store would look something like the following: + *

+ * + *
+ * ...
+ * Foo bar +yes -
+ * Foo bar .yes .child +no +key1   [instance 1 key or data, depending on index type]
+ * Foo bar .yes .child +no +key2   [instance 2 key or data, depending on index type]
+ * ...
+ * 
+ */ +class LevelDBTypeInfo { + + static final byte[] END_MARKER = new byte[] { '-' }; + static final byte ENTRY_PREFIX = (byte) '+'; + static final byte KEY_SEPARATOR = 0x0; + static byte TRUE = (byte) '1'; + static byte FALSE = (byte) '0'; + + private static final byte SECONDARY_IDX_PREFIX = (byte) '.'; + private static final byte POSITIVE_MARKER = (byte) '='; + private static final byte NEGATIVE_MARKER = (byte) '*'; + private static final byte[] HEX_BYTES = new byte[] { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' + }; + + private final LevelDB db; + private final Class type; + private final Map indices; + private final byte[] typePrefix; + + LevelDBTypeInfo(LevelDB db, Class type, byte[] alias) throws Exception { + this.db = db; + this.type = type; + this.indices = new HashMap<>(); + + KVTypeInfo ti = new KVTypeInfo(type); + + // First create the parent indices, then the child indices. + ti.indices().forEach(idx -> { + if (idx.parent().isEmpty()) { + indices.put(idx.value(), new Index(idx, ti.getAccessor(idx.value()), null)); + } + }); + ti.indices().forEach(idx -> { + if (!idx.parent().isEmpty()) { + indices.put(idx.value(), new Index(idx, ti.getAccessor(idx.value()), + indices.get(idx.parent()))); + } + }); + + this.typePrefix = alias; + } + + Class type() { + return type; + } + + byte[] keyPrefix() { + return typePrefix; + } + + Index naturalIndex() { + return index(KVIndex.NATURAL_INDEX_NAME); + } + + Index index(String name) { + Index i = indices.get(name); + Preconditions.checkArgument(i != null, "Index %s does not exist for type %s.", name, + type.getName()); + return i; + } + + Collection indices() { + return indices.values(); + } + + byte[] buildKey(byte[]... components) { + return buildKey(true, components); + } + + byte[] buildKey(boolean addTypePrefix, byte[]... components) { + int len = 0; + if (addTypePrefix) { + len += typePrefix.length + 1; + } + for (byte[] comp : components) { + len += comp.length; + } + len += components.length - 1; + + byte[] dest = new byte[len]; + int written = 0; + + if (addTypePrefix) { + System.arraycopy(typePrefix, 0, dest, 0, typePrefix.length); + dest[typePrefix.length] = KEY_SEPARATOR; + written += typePrefix.length + 1; + } + + for (byte[] comp : components) { + System.arraycopy(comp, 0, dest, written, comp.length); + written += comp.length; + if (written < dest.length) { + dest[written] = KEY_SEPARATOR; + written++; + } + } + + return dest; + } + + /** + * Models a single index in LevelDB. See top-level class's javadoc for a description of how the + * keys are generated. + */ + class Index { + + private final boolean copy; + private final boolean isNatural; + private final byte[] name; + private final KVTypeInfo.Accessor accessor; + private final Index parent; + + private Index(KVIndex self, KVTypeInfo.Accessor accessor, Index parent) { + byte[] name = self.value().getBytes(UTF_8); + if (parent != null) { + byte[] child = new byte[name.length + 1]; + child[0] = SECONDARY_IDX_PREFIX; + System.arraycopy(name, 0, child, 1, name.length); + } + + this.name = name; + this.isNatural = self.value().equals(KVIndex.NATURAL_INDEX_NAME); + this.copy = isNatural || self.copy(); + this.accessor = accessor; + this.parent = parent; + } + + boolean isCopy() { + return copy; + } + + boolean isChild() { + return parent != null; + } + + Index parent() { + return parent; + } + + /** + * Creates a key prefix for child indices of this index. 
This allows the prefix to be + * calculated only once, avoiding redundant work when multiple child indices of the + * same parent index exist. + */ + byte[] childPrefix(Object value) throws Exception { + Preconditions.checkState(parent == null, "Not a parent index."); + return buildKey(name, toParentKey(value)); + } + + /** + * Gets the index value for a particular entity (which is the value of the field or method + * tagged with the index annotation). This is used as part of the LevelDB key where the + * entity (or its id) is stored. + */ + Object getValue(Object entity) throws Exception { + return accessor.get(entity); + } + + private void checkParent(byte[] prefix) { + if (prefix != null) { + Preconditions.checkState(parent != null, "Parent prefix provided for parent index."); + } else { + Preconditions.checkState(parent == null, "Parent prefix missing for child index."); + } + } + + /** The prefix for all keys that belong to this index. */ + byte[] keyPrefix(byte[] prefix) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name) : buildKey(name); + } + + /** + * The key where to start ascending iteration for entities whose value for the indexed field + * match the given value. + */ + byte[] start(byte[] prefix, Object value) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name, toKey(value)) + : buildKey(name, toKey(value)); + } + + /** The key for the index's end marker. */ + byte[] end(byte[] prefix) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name, END_MARKER) + : buildKey(name, END_MARKER); + } + + /** The key for the end marker for entries with the given value. */ + byte[] end(byte[] prefix, Object value) throws Exception { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name, toKey(value), END_MARKER) + : buildKey(name, toKey(value), END_MARKER); + } + + /** The full key in the index that identifies the given entity. */ + byte[] entityKey(byte[] prefix, Object entity) throws Exception { + Object indexValue = getValue(entity); + Preconditions.checkNotNull(indexValue, "Null index value for %s in type %s.", + name, type.getName()); + byte[] entityKey = start(prefix, indexValue); + if (!isNatural) { + entityKey = buildKey(false, entityKey, toKey(naturalIndex().getValue(entity))); + } + return entityKey; + } + + private void updateCount(WriteBatch batch, byte[] key, long delta) throws Exception { + long updated = getCount(key) + delta; + if (updated > 0) { + batch.put(key, db.serializer.serialize(updated)); + } else { + batch.delete(key); + } + } + + private void addOrRemove( + WriteBatch batch, + Object entity, + Object existing, + byte[] data, + byte[] naturalKey, + byte[] prefix) throws Exception { + Object indexValue = getValue(entity); + Preconditions.checkNotNull(indexValue, "Null index value for %s in type %s.", + name, type.getName()); + + byte[] entityKey = start(prefix, indexValue); + if (!isNatural) { + entityKey = buildKey(false, entityKey, naturalKey); + } + + boolean needCountUpdate = (existing == null); + + // Check whether there's a need to update the index. The index needs to be updated in two + // cases: + // + // - There is no existing value for the entity, so a new index value will be added. + // - If there is a previously stored value for the entity, and the index value for the + // current index does not match the new value, the old entry needs to be deleted and + // the new one added. 
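+      //
+      // For example, if an entity's indexed "name" value changes from "a" to "b", the entry
+      // stored under the old "a" key is deleted, the count at the old "a" end marker is
+      // decremented, and the entry and count for "b" are written below.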
+ // + // Natural indices don't need to be checked, because by definition both old and new entities + // will have the same key. The put() call is all that's needed in that case. + // + // Also check whether we need to update the counts. If the indexed value is changing, we + // need to decrement the count at the old index value, and the new indexed value count needs + // to be incremented. + if (existing != null && !isNatural) { + byte[] oldPrefix = null; + Object oldIndexedValue = getValue(existing); + boolean removeExisting = !indexValue.equals(oldIndexedValue); + if (!removeExisting && isChild()) { + oldPrefix = parent().childPrefix(parent().getValue(existing)); + removeExisting = LevelDBIterator.compare(prefix, oldPrefix) != 0; + } + + if (removeExisting) { + if (oldPrefix == null && isChild()) { + oldPrefix = parent().childPrefix(parent().getValue(existing)); + } + + byte[] oldKey = entityKey(oldPrefix, existing); + batch.delete(oldKey); + + // If the indexed value has changed, we need to update the counts at the old and new + // end markers for the indexed value. + if (!isChild()) { + byte[] oldCountKey = end(null, oldIndexedValue); + updateCount(batch, oldCountKey, -1L); + needCountUpdate = true; + } + } + } + + if (data != null) { + byte[] stored = copy ? data : naturalKey; + batch.put(entityKey, stored); + } else { + batch.delete(entityKey); + } + + if (needCountUpdate && !isChild()) { + long delta = data != null ? 1L : -1L; + byte[] countKey = isNatural ? end(prefix) : end(prefix, indexValue); + updateCount(batch, countKey, delta); + } + } + + /** + * Add an entry to the index. + * + * @param batch Write batch with other related changes. + * @param entity The entity being added to the index. + * @param existing The entity being replaced in the index, or null. + * @param data Serialized entity to store (when storing the entity, not a reference). + * @param naturalKey The value's natural key (to avoid re-computing it for every index). + * @param prefix The parent index prefix, if this is a child index. + */ + void add( + WriteBatch batch, + Object entity, + Object existing, + byte[] data, + byte[] naturalKey, + byte[] prefix) throws Exception { + addOrRemove(batch, entity, existing, data, naturalKey, prefix); + } + + /** + * Remove a value from the index. + * + * @param batch Write batch with other related changes. + * @param entity The entity being removed, to identify the index entry to modify. + * @param naturalKey The value's natural key (to avoid re-computing it for every index). + * @param prefix The parent index prefix, if this is a child index. + */ + void remove( + WriteBatch batch, + Object entity, + byte[] naturalKey, + byte[] prefix) throws Exception { + addOrRemove(batch, entity, null, null, naturalKey, prefix); + } + + long getCount(byte[] key) throws Exception { + byte[] data = db.db().get(key); + return data != null ? db.serializer.deserializeLong(data) : 0; + } + + byte[] toParentKey(Object value) { + return toKey(value, SECONDARY_IDX_PREFIX); + } + + byte[] toKey(Object value) { + return toKey(value, ENTRY_PREFIX); + } + + /** + * Translates a value to be used as part of the store key. + * + * Integral numbers are encoded as a string in a way that preserves lexicographical + * ordering. 
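+     * For example, the int value 1 is encoded as "+=00000001" and 16 as "+=00000010" (see
+     * LevelDBTypeInfoSuite.testNumEncoding), so a byte-wise comparison of the encoded keys
+     * follows numeric order.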
The string is prepended with a marker telling whether the number is negative + * or positive ("*" for negative and "=" for positive are used since "-" and "+" have the + * opposite of the desired order), and then the number is encoded into a hex string (so + * it occupies twice the number of bytes as the original type). + * + * Arrays are encoded by encoding each element separately, separated by KEY_SEPARATOR. + */ + byte[] toKey(Object value, byte prefix) { + final byte[] result; + + if (value instanceof String) { + byte[] str = ((String) value).getBytes(UTF_8); + result = new byte[str.length + 1]; + result[0] = prefix; + System.arraycopy(str, 0, result, 1, str.length); + } else if (value instanceof Boolean) { + result = new byte[] { prefix, (Boolean) value ? TRUE : FALSE }; + } else if (value.getClass().isArray()) { + int length = Array.getLength(value); + byte[][] components = new byte[length][]; + for (int i = 0; i < length; i++) { + components[i] = toKey(Array.get(value, i)); + } + result = buildKey(false, components); + } else { + int bytes; + + if (value instanceof Integer) { + bytes = Integer.SIZE; + } else if (value instanceof Long) { + bytes = Long.SIZE; + } else if (value instanceof Short) { + bytes = Short.SIZE; + } else if (value instanceof Byte) { + bytes = Byte.SIZE; + } else { + throw new IllegalArgumentException(String.format("Type %s not allowed as key.", + value.getClass().getName())); + } + + bytes = bytes / Byte.SIZE; + + byte[] key = new byte[bytes * 2 + 2]; + long longValue = ((Number) value).longValue(); + key[0] = prefix; + key[1] = longValue > 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; + + for (int i = 0; i < key.length - 2; i++) { + int masked = (int) ((longValue >>> (4 * i)) & 0xF); + key[key.length - i - 1] = HEX_BYTES[masked]; + } + + result = key; + } + + return result; + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java new file mode 100644 index 0000000000000..2ed246e4f4c97 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.IOException; + +/** + * Exception thrown when the store implementation is not compatible with the underlying data. 
+ */ +public class UnsupportedStoreVersionException extends IOException { + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java new file mode 100644 index 0000000000000..afb72b8689223 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import com.google.common.base.Objects; + +public class CustomType1 { + + @KVIndex + public String key; + + @KVIndex("id") + public String id; + + @KVIndex(value = "name", copy = true) + public String name; + + @KVIndex("int") + public int num; + + @KVIndex(value = "child", parent = "id") + public String child; + + @Override + public boolean equals(Object o) { + if (o instanceof CustomType1) { + CustomType1 other = (CustomType1) o; + return id.equals(other.id) && name.equals(other.name); + } + return false; + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("key", key) + .add("id", id) + .add("name", name) + .add("num", num) + .toString(); + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java new file mode 100644 index 0000000000000..8549712213393 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java @@ -0,0 +1,506 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.kvstore; + +import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +import com.google.common.base.Predicate; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import org.apache.commons.io.FileUtils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import static org.junit.Assert.*; + +public abstract class DBIteratorSuite { + + private static final Logger LOG = LoggerFactory.getLogger(DBIteratorSuite.class); + + private static final int MIN_ENTRIES = 42; + private static final int MAX_ENTRIES = 1024; + private static final Random RND = new Random(); + + private static List allEntries; + private static List clashingEntries; + private static KVStore db; + + private static interface BaseComparator extends Comparator { + /** + * Returns a comparator that falls back to natural order if this comparator's ordering + * returns equality for two elements. Used to mimic how the index sorts things internally. + */ + default BaseComparator fallback() { + return (t1, t2) -> { + int result = BaseComparator.this.compare(t1, t2); + if (result != 0) { + return result; + } + + return t1.key.compareTo(t2.key); + }; + } + + /** Reverses the order of this comparator. */ + default BaseComparator reverse() { + return (t1, t2) -> -BaseComparator.this.compare(t1, t2); + } + } + + private static final BaseComparator NATURAL_ORDER = (t1, t2) -> t1.key.compareTo(t2.key); + private static final BaseComparator REF_INDEX_ORDER = (t1, t2) -> t1.id.compareTo(t2.id); + private static final BaseComparator COPY_INDEX_ORDER = (t1, t2) -> t1.name.compareTo(t2.name); + private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> t1.num - t2.num; + private static final BaseComparator CHILD_INDEX_ORDER = (t1, t2) -> t1.child.compareTo(t2.child); + + /** + * Implementations should override this method; it is called only once, before all tests are + * run. Any state can be safely stored in static variables and cleaned up in a @AfterClass + * handler. + */ + protected abstract KVStore createStore() throws Exception; + + @BeforeClass + public static void setupClass() { + long seed = RND.nextLong(); + LOG.info("Random seed: {}", seed); + RND.setSeed(seed); + } + + @AfterClass + public static void cleanupData() throws Exception { + allEntries = null; + db = null; + } + + @Before + public void setup() throws Exception { + if (db != null) { + return; + } + + db = createStore(); + + int count = RND.nextInt(MAX_ENTRIES) + MIN_ENTRIES; + + allEntries = new ArrayList<>(count); + for (int i = 0; i < count; i++) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + RND.nextInt(MAX_ENTRIES); + t.num = RND.nextInt(MAX_ENTRIES); + t.child = "child" + (i % MIN_ENTRIES); + allEntries.add(t); + } + + // Shuffle the entries to avoid the insertion order matching the natural ordering. Just in case. + Collections.shuffle(allEntries, RND); + for (CustomType1 e : allEntries) { + db.write(e); + } + + // Pick the first generated value, and forcefully create a few entries that will clash + // with the indexed values (id and name), to make sure the index behaves correctly when + // multiple entities are indexed by the same value. 
+ // + // This also serves as a test for the test code itself, to make sure it's sorting indices + // the same way the store is expected to. + CustomType1 first = allEntries.get(0); + clashingEntries = new ArrayList<>(); + + int clashCount = RND.nextInt(MIN_ENTRIES) + 1; + for (int i = 0; i < clashCount; i++) { + CustomType1 t = new CustomType1(); + t.key = "n-key" + (count + i); + t.id = first.id; + t.name = first.name; + t.num = first.num; + t.child = first.child; + allEntries.add(t); + clashingEntries.add(t); + db.write(t); + } + + // Create another entry that could cause problems: take the first entry, and make its indexed + // name be an extension of the existing ones, to make sure the implementation sorts these + // correctly even considering the separator character (shorter strings first). + CustomType1 t = new CustomType1(); + t.key = "extended-key-0"; + t.id = first.id; + t.name = first.name + "a"; + t.num = first.num; + t.child = first.child; + allEntries.add(t); + db.write(t); + } + + @Test + public void naturalIndex() throws Exception { + testIteration(NATURAL_ORDER, view(), null, null); + } + + @Test + public void refIndex() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id"), null, null); + } + + @Test + public void copyIndex() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name"), null, null); + } + + @Test + public void numericIndex() throws Exception { + testIteration(NUMERIC_INDEX_ORDER, view().index("int"), null, null); + } + + @Test + public void childIndex() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id), null, null); + } + + @Test + public void naturalIndexDescending() throws Exception { + testIteration(NATURAL_ORDER, view().reverse(), null, null); + } + + @Test + public void refIndexDescending() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id").reverse(), null, null); + } + + @Test + public void copyIndexDescending() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").reverse(), null, null); + } + + @Test + public void numericIndexDescending() throws Exception { + testIteration(NUMERIC_INDEX_ORDER, view().index("int").reverse(), null, null); + } + + @Test + public void childIndexDescending() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).reverse(), null, null); + } + + @Test + public void naturalIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NATURAL_ORDER, view().first(first.key), first, null); + } + + @Test + public void refIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").first(first.id), first, null); + } + + @Test + public void copyIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").first(first.name), first, null); + } + + @Test + public void numericIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").first(first.num), first, null); + } + + @Test + public void childIndexWithStart() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).first(any.child), null, + null); + } + + @Test + public void naturalIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + 
testIteration(NATURAL_ORDER, view().reverse().first(first.key), first, null); + } + + @Test + public void refIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").first(first.id), first, null); + } + + @Test + public void copyIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").first(first.name), first, null); + } + + @Test + public void numericIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").first(first.num), first, null); + } + + @Test + public void childIndexDescendingWithStart() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, + view().index("child").parent(any.id).first(any.child).reverse(), null, null); + } + + @Test + public void naturalIndexWithSkip() throws Exception { + testIteration(NATURAL_ORDER, view().skip(pickCount()), null, null); + } + + @Test + public void refIndexWithSkip() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id").skip(pickCount()), null, null); + } + + @Test + public void copyIndexWithSkip() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").skip(pickCount()), null, null); + } + + @Test + public void childIndexWithSkip() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).skip(pickCount()), + null, null); + } + + @Test + public void naturalIndexWithMax() throws Exception { + testIteration(NATURAL_ORDER, view().max(pickCount()), null, null); + } + + @Test + public void copyIndexWithMax() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").max(pickCount()), null, null); + } + + @Test + public void childIndexWithMax() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).max(pickCount()), null, + null); + } + + @Test + public void naturalIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().last(last.key), null, last); + } + + @Test + public void refIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").last(last.name), null, last); + } + + @Test + public void numericIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").last(last.num), null, last); + } + + @Test + public void childIndexWithLast() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).last(any.child), null, + null); + } + + @Test + public void naturalIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().reverse().last(last.key), null, last); + } + + @Test + public void refIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, 
view().reverse().index("name").last(last.name), + null, last); + } + + @Test + public void numericIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").last(last.num), + null, last); + } + + @Test + public void childIndexDescendingWithLast() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).last(any.child).reverse(), + null, null); + } + + @Test + public void testRefWithIntNaturalKey() throws Exception { + LevelDBSuite.IntKeyType i = new LevelDBSuite.IntKeyType(); + i.key = 1; + i.id = "1"; + i.values = Arrays.asList("1"); + + db.write(i); + + try(KVStoreIterator it = db.view(i.getClass()).closeableIterator()) { + Object read = it.next(); + assertEquals(i, read); + } + } + + private CustomType1 pickLimit() { + // Picks an element that has clashes with other elements in the given index. + return clashingEntries.get(RND.nextInt(clashingEntries.size())); + } + + private int pickCount() { + int count = RND.nextInt(allEntries.size() / 2); + return Math.max(count, 1); + } + + /** + * Compares the two values and falls back to comparing the natural key of CustomType1 + * if they're the same, to mimic the behavior of the indexing code. + */ + private > int compareWithFallback( + T v1, + T v2, + CustomType1 ct1, + CustomType1 ct2) { + int result = v1.compareTo(v2); + if (result != 0) { + return result; + } + + return ct1.key.compareTo(ct2.key); + } + + private void testIteration( + final BaseComparator order, + final KVStoreView params, + final CustomType1 first, + final CustomType1 last) throws Exception { + List indexOrder = sortBy(order.fallback()); + if (!params.ascending) { + indexOrder = Lists.reverse(indexOrder); + } + + Iterable expected = indexOrder; + BaseComparator expectedOrder = params.ascending ? order : order.reverse(); + + if (params.parent != null) { + expected = Iterables.filter(expected, v -> params.parent.equals(v.id)); + } + + if (first != null) { + expected = Iterables.filter(expected, v -> expectedOrder.compare(first, v) <= 0); + } + + if (last != null) { + expected = Iterables.filter(expected, v -> expectedOrder.compare(v, last) <= 0); + } + + if (params.skip > 0) { + expected = Iterables.skip(expected, (int) params.skip); + } + + if (params.max != Long.MAX_VALUE) { + expected = Iterables.limit(expected, (int) params.max); + } + + List actual = collect(params); + compareLists(expected, actual); + } + + /** Could use assertEquals(), but that creates hard to read errors for large lists. 
*/ + private void compareLists(Iterable expected, List actual) { + Iterator expectedIt = expected.iterator(); + Iterator actualIt = actual.iterator(); + + int count = 0; + while (expectedIt.hasNext()) { + if (!actualIt.hasNext()) { + break; + } + count++; + assertEquals(expectedIt.next(), actualIt.next()); + } + + String message; + Object[] remaining; + int expectedCount = count; + int actualCount = count; + + if (expectedIt.hasNext()) { + remaining = Iterators.toArray(expectedIt, Object.class); + expectedCount += remaining.length; + message = "missing"; + } else { + remaining = Iterators.toArray(actualIt, Object.class); + actualCount += remaining.length; + message = "stray"; + } + + assertEquals(String.format("Found %s elements: %s", message, Arrays.asList(remaining)), + expectedCount, actualCount); + } + + private KVStoreView view() throws Exception { + return db.view(CustomType1.class); + } + + private List collect(KVStoreView view) throws Exception { + return Arrays.asList(Iterables.toArray(view, CustomType1.class)); + } + + private List sortBy(Comparator comp) { + List copy = new ArrayList<>(allEntries); + Collections.sort(copy, comp); + return copy; + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java new file mode 100644 index 0000000000000..5e33606b12dd4 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.Slf4jReporter; +import com.codahale.metrics.Snapshot; +import com.codahale.metrics.Timer; +import org.apache.commons.io.FileUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.LoggerFactory; +import static org.junit.Assert.*; + +/** + * A set of small benchmarks for the LevelDB implementation. 
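+ * The class is annotated with {@code @Ignore}, so it is skipped during normal test runs; remove
+ * the annotation to collect numbers locally.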
+ * + * The benchmarks are run over two different types (one with just a natural index, and one + * with a ref index), over a set of 2^20 elements, and the following tests are performed: + * + * - write (then update) elements in sequential natural key order + * - write (then update) elements in random natural key order + * - iterate over natural index, ascending and descending + * - iterate over ref index, ascending and descending + */ +@Ignore +public class LevelDBBenchmark { + + private static final int COUNT = 1024; + private static final AtomicInteger IDGEN = new AtomicInteger(); + private static final MetricRegistry metrics = new MetricRegistry(); + private static final Timer dbCreation = metrics.timer("dbCreation"); + private static final Timer dbClose = metrics.timer("dbClose"); + + private LevelDB db; + private File dbpath; + + @Before + public void setup() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + try(Timer.Context ctx = dbCreation.time()) { + db = new LevelDB(dbpath); + } + } + + @After + public void cleanup() throws Exception { + if (db != null) { + try(Timer.Context ctx = dbClose.time()) { + db.close(); + } + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @AfterClass + public static void report() { + if (metrics.getTimers().isEmpty()) { + return; + } + + int headingPrefix = 0; + for (Map.Entry e : metrics.getTimers().entrySet()) { + headingPrefix = Math.max(e.getKey().length(), headingPrefix); + } + headingPrefix += 4; + + StringBuilder heading = new StringBuilder(); + for (int i = 0; i < headingPrefix; i++) { + heading.append(" "); + } + heading.append("\tcount"); + heading.append("\tmean"); + heading.append("\tmin"); + heading.append("\tmax"); + heading.append("\t95th"); + System.out.println(heading); + + for (Map.Entry e : metrics.getTimers().entrySet()) { + StringBuilder row = new StringBuilder(); + row.append(e.getKey()); + for (int i = 0; i < headingPrefix - e.getKey().length(); i++) { + row.append(" "); + } + + Snapshot s = e.getValue().getSnapshot(); + row.append("\t").append(e.getValue().getCount()); + row.append("\t").append(toMs(s.getMean())); + row.append("\t").append(toMs(s.getMin())); + row.append("\t").append(toMs(s.getMax())); + row.append("\t").append(toMs(s.get95thPercentile())); + + System.out.println(row); + } + + Slf4jReporter.forRegistry(metrics).outputTo(LoggerFactory.getLogger(LevelDBBenchmark.class)) + .build().report(); + } + + private static String toMs(double nanos) { + return String.format("%.3f", nanos / 1000 / 1000); + } + + @Test + public void sequentialWritesNoIndex() throws Exception { + List entries = createSimpleType(); + writeAll(entries, "sequentialWritesNoIndex"); + writeAll(entries, "sequentialUpdatesNoIndex"); + deleteNoIndex(entries, "sequentialDeleteNoIndex"); + } + + @Test + public void randomWritesNoIndex() throws Exception { + List entries = createSimpleType(); + + Collections.shuffle(entries); + writeAll(entries, "randomWritesNoIndex"); + + Collections.shuffle(entries); + writeAll(entries, "randomUpdatesNoIndex"); + + Collections.shuffle(entries); + deleteNoIndex(entries, "randomDeletesNoIndex"); + } + + @Test + public void sequentialWritesIndexedType() throws Exception { + List entries = createIndexedType(); + writeAll(entries, "sequentialWritesIndexed"); + writeAll(entries, "sequentialUpdatesIndexed"); + deleteIndexed(entries, "sequentialDeleteIndexed"); + } + + @Test + public void randomWritesIndexedTypeAndIteration() throws Exception { + List entries = 
createIndexedType(); + + Collections.shuffle(entries); + writeAll(entries, "randomWritesIndexed"); + + Collections.shuffle(entries); + writeAll(entries, "randomUpdatesIndexed"); + + // Run iteration benchmarks here since we've gone through the trouble of writing all + // the data already. + KVStoreView view = db.view(IndexedType.class); + iterate(view, "naturalIndex"); + iterate(view.reverse(), "naturalIndexDescending"); + iterate(view.index("name"), "refIndex"); + iterate(view.index("name").reverse(), "refIndexDescending"); + + Collections.shuffle(entries); + deleteIndexed(entries, "randomDeleteIndexed"); + } + + private void iterate(KVStoreView view, String name) throws Exception { + Timer create = metrics.timer(name + "CreateIterator"); + Timer iter = metrics.timer(name + "Iteration"); + KVStoreIterator it = null; + { + // Create the iterator several times, just to have multiple data points. + for (int i = 0; i < 1024; i++) { + if (it != null) { + it.close(); + } + try(Timer.Context ctx = create.time()) { + it = view.closeableIterator(); + } + } + } + + for (; it.hasNext(); ) { + try(Timer.Context ctx = iter.time()) { + it.next(); + } + } + } + + private void writeAll(List entries, String timerName) throws Exception { + Timer timer = newTimer(timerName); + for (Object o : entries) { + try(Timer.Context ctx = timer.time()) { + db.write(o); + } + } + } + + private void deleteNoIndex(List entries, String timerName) throws Exception { + Timer delete = newTimer(timerName); + for (SimpleType i : entries) { + try(Timer.Context ctx = delete.time()) { + db.delete(i.getClass(), i.key); + } + } + } + + private void deleteIndexed(List entries, String timerName) throws Exception { + Timer delete = newTimer(timerName); + for (IndexedType i : entries) { + try(Timer.Context ctx = delete.time()) { + db.delete(i.getClass(), i.key); + } + } + } + + private List createSimpleType() { + List entries = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + SimpleType t = new SimpleType(); + t.key = IDGEN.getAndIncrement(); + t.name = "name" + (t.key % 1024); + entries.add(t); + } + return entries; + } + + private List createIndexedType() { + List entries = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + IndexedType t = new IndexedType(); + t.key = IDGEN.getAndIncrement(); + t.name = "name" + (t.key % 1024); + entries.add(t); + } + return entries; + } + + private Timer newTimer(String name) { + assertNull("Timer already exists: " + name, metrics.getTimers().get(name)); + return metrics.timer(name); + } + + public static class SimpleType { + + @KVIndex + public int key; + + public String name; + + } + + public static class IndexedType { + + @KVIndex + public int key; + + @KVIndex("name") + public String name; + + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java new file mode 100644 index 0000000000000..93409712986ca --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; + +import org.apache.commons.io.FileUtils; +import org.junit.AfterClass; + +public class LevelDBIteratorSuite extends DBIteratorSuite { + + private static File dbpath; + private static LevelDB db; + + @AfterClass + public static void cleanup() throws Exception { + if (db != null) { + db.close(); + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @Override + protected KVStore createStore() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + db = new LevelDB(dbpath); + return db; + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java new file mode 100644 index 0000000000000..ee1c397c08573 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import static java.nio.charset.StandardCharsets.UTF_8; + +import org.apache.commons.io.FileUtils; +import org.iq80.leveldb.DBIterator; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.*; + +public class LevelDBSuite { + + private LevelDB db; + private File dbpath; + + @After + public void cleanup() throws Exception { + if (db != null) { + db.close(); + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @Before + public void setup() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + db = new LevelDB(dbpath); + } + + @Test + public void testReopenAndVersionCheckDb() throws Exception { + db.close(); + db = null; + assertTrue(dbpath.exists()); + + db = new LevelDB(dbpath); + assertEquals(LevelDB.STORE_VERSION, + db.serializer.deserializeLong(db.db().get(LevelDB.STORE_VERSION_KEY))); + db.db().put(LevelDB.STORE_VERSION_KEY, db.serializer.serialize(LevelDB.STORE_VERSION + 1)); + db.close(); + db = null; + + try { + db = new LevelDB(dbpath); + fail("Should have failed version check."); + } catch (UnsupportedStoreVersionException e) { + // Expected. 
+ } + } + + @Test + public void testObjectWriteReadDelete() throws Exception { + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + t.child = "child"; + + try { + db.read(CustomType1.class, t.key); + fail("Expected exception for non-existant object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + db.write(t); + assertEquals(t, db.read(t.getClass(), t.key)); + assertEquals(1L, db.count(t.getClass())); + + db.delete(t.getClass(), t.key); + try { + db.read(t.getClass(), t.key); + fail("Expected exception for deleted object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + // Look into the actual DB and make sure that all the keys related to the type have been + // removed. + assertEquals(0, countKeys(t.getClass())); + } + + @Test + public void testMultipleObjectWriteReadDelete() throws Exception { + CustomType1 t1 = new CustomType1(); + t1.key = "key1"; + t1.id = "id"; + t1.name = "name1"; + t1.child = "child1"; + + CustomType1 t2 = new CustomType1(); + t2.key = "key2"; + t2.id = "id"; + t2.name = "name2"; + t2.child = "child2"; + + db.write(t1); + db.write(t2); + + assertEquals(t1, db.read(t1.getClass(), t1.key)); + assertEquals(t2, db.read(t2.getClass(), t2.key)); + assertEquals(2L, db.count(t1.getClass())); + + // There should be one "id" index entry with two values. + assertEquals(2, db.count(t1.getClass(), "id", t1.id)); + + // Delete the first entry; now there should be 3 remaining keys, since one of the "name" + // index entries should have been removed. + db.delete(t1.getClass(), t1.key); + + // Make sure there's a single entry in the "id" index now. + assertEquals(1, db.count(t2.getClass(), "id", t2.id)); + + // Delete the remaining entry, make sure all data is gone. + db.delete(t2.getClass(), t2.key); + assertEquals(0, countKeys(t2.getClass())); + } + + @Test + public void testMultipleTypesWriteReadDelete() throws Exception { + CustomType1 t1 = new CustomType1(); + t1.key = "1"; + t1.id = "id"; + t1.name = "name1"; + t1.child = "child1"; + + IntKeyType t2 = new IntKeyType(); + t2.key = 2; + t2.id = "2"; + t2.values = Arrays.asList("value1", "value2"); + + ArrayKeyIndexType t3 = new ArrayKeyIndexType(); + t3.key = new int[] { 42, 84 }; + t3.id = new String[] { "id1", "id2" }; + + db.write(t1); + db.write(t2); + db.write(t3); + + assertEquals(t1, db.read(t1.getClass(), t1.key)); + assertEquals(t2, db.read(t2.getClass(), t2.key)); + assertEquals(t3, db.read(t3.getClass(), t3.key)); + + // There should be one "id" index with a single entry for each type. + assertEquals(1, db.count(t1.getClass(), "id", t1.id)); + assertEquals(1, db.count(t2.getClass(), "id", t2.id)); + assertEquals(1, db.count(t3.getClass(), "id", t3.id)); + + // Delete the first entry; this should not affect the entries for the second type. + db.delete(t1.getClass(), t1.key); + assertEquals(0, countKeys(t1.getClass())); + assertEquals(1, db.count(t2.getClass(), "id", t2.id)); + assertEquals(1, db.count(t3.getClass(), "id", t3.id)); + + // Delete the remaining entries, make sure all data is gone. 
+ db.delete(t2.getClass(), t2.key); + assertEquals(0, countKeys(t2.getClass())); + + db.delete(t3.getClass(), t3.key); + assertEquals(0, countKeys(t3.getClass())); + } + + @Test + public void testMetadata() throws Exception { + assertNull(db.getMetadata(CustomType1.class)); + + CustomType1 t = new CustomType1(); + t.id = "id"; + t.name = "name"; + t.child = "child"; + + db.setMetadata(t); + assertEquals(t, db.getMetadata(CustomType1.class)); + + db.setMetadata(null); + assertNull(db.getMetadata(CustomType1.class)); + } + + @Test + public void testUpdate() throws Exception { + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + t.child = "child"; + + db.write(t); + + t.name = "anotherName"; + + db.write(t); + + assertEquals(1, db.count(t.getClass())); + assertEquals(1, db.count(t.getClass(), "name", "anotherName")); + assertEquals(0, db.count(t.getClass(), "name", "name")); + } + + @Test + public void testSkip() throws Exception { + for (int i = 0; i < 10; i++) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + i; + t.child = "child" + i; + + db.write(t); + } + + KVStoreIterator it = db.view(CustomType1.class).closeableIterator(); + assertTrue(it.hasNext()); + assertTrue(it.skip(5)); + assertEquals("key5", it.next().key); + assertTrue(it.skip(3)); + assertEquals("key9", it.next().key); + assertFalse(it.hasNext()); + } + + private int countKeys(Class type) throws Exception { + byte[] prefix = db.getTypeInfo(type).keyPrefix(); + int count = 0; + + DBIterator it = db.db().iterator(); + it.seek(prefix); + + while (it.hasNext()) { + byte[] key = it.next().getKey(); + if (LevelDBIterator.startsWith(key, prefix)) { + count++; + } + } + + return count; + } + + public static class IntKeyType { + + @KVIndex + public int key; + + @KVIndex("id") + public String id; + + public List values; + + @Override + public boolean equals(Object o) { + if (o instanceof IntKeyType) { + IntKeyType other = (IntKeyType) o; + return key == other.key && id.equals(other.id) && values.equals(other.values); + } + return false; + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + } + + public static class ArrayKeyIndexType { + + @KVIndex + public int[] key; + + @KVIndex("id") + public String[] id; + + @Override + public boolean equals(Object o) { + if (o instanceof ArrayKeyIndexType) { + ArrayKeyIndexType other = (ArrayKeyIndexType) o; + return Arrays.equals(key, other.key) && Arrays.equals(id, other.id); + } + return false; + } + + @Override + public int hashCode() { + return key.hashCode(); + } + + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java new file mode 100644 index 0000000000000..8e6196506c6a8 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class LevelDBTypeInfoSuite { + + @Test + public void testIndexAnnotation() throws Exception { + KVTypeInfo ti = new KVTypeInfo(CustomType1.class); + assertEquals(5, ti.indices().count()); + + CustomType1 t1 = new CustomType1(); + t1.key = "key"; + t1.id = "id"; + t1.name = "name"; + t1.num = 42; + t1.child = "child"; + + assertEquals(t1.key, ti.getIndexValue(KVIndex.NATURAL_INDEX_NAME, t1)); + assertEquals(t1.id, ti.getIndexValue("id", t1)); + assertEquals(t1.name, ti.getIndexValue("name", t1)); + assertEquals(t1.num, ti.getIndexValue("int", t1)); + assertEquals(t1.child, ti.getIndexValue("child", t1)); + } + + @Test(expected = IllegalArgumentException.class) + public void testNoNaturalIndex() throws Exception { + newTypeInfo(NoNaturalIndex.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testNoNaturalIndex2() throws Exception { + newTypeInfo(NoNaturalIndex2.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testDuplicateIndex() throws Exception { + newTypeInfo(DuplicateIndex.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyIndexName() throws Exception { + newTypeInfo(EmptyIndexName.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalIndexName() throws Exception { + newTypeInfo(IllegalIndexName.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalIndexMethod() throws Exception { + newTypeInfo(IllegalIndexMethod.class); + } + + @Test + public void testKeyClashes() throws Exception { + LevelDBTypeInfo ti = newTypeInfo(CustomType1.class); + + CustomType1 t1 = new CustomType1(); + t1.key = "key1"; + t1.name = "a"; + + CustomType1 t2 = new CustomType1(); + t2.key = "key2"; + t2.name = "aa"; + + CustomType1 t3 = new CustomType1(); + t3.key = "key3"; + t3.name = "aaa"; + + // Make sure entries with conflicting names are sorted correctly. 
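+    // Since the key separator is a NUL byte, the shorter names must sort before their
+    // extensions, i.e. the full keys for "a" < "aa" < "aaa".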
+ assertBefore(ti.index("name").entityKey(null, t1), ti.index("name").entityKey(null, t2)); + assertBefore(ti.index("name").entityKey(null, t1), ti.index("name").entityKey(null, t3)); + assertBefore(ti.index("name").entityKey(null, t2), ti.index("name").entityKey(null, t3)); + } + + @Test + public void testNumEncoding() throws Exception { + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + + assertEquals("+=00000001", new String(idx.toKey(1), UTF_8)); + assertEquals("+=00000010", new String(idx.toKey(16), UTF_8)); + assertEquals("+=7fffffff", new String(idx.toKey(Integer.MAX_VALUE), UTF_8)); + + assertBefore(idx.toKey(1), idx.toKey(2)); + assertBefore(idx.toKey(-1), idx.toKey(2)); + assertBefore(idx.toKey(-11), idx.toKey(2)); + assertBefore(idx.toKey(-11), idx.toKey(-1)); + assertBefore(idx.toKey(1), idx.toKey(11)); + assertBefore(idx.toKey(Integer.MIN_VALUE), idx.toKey(Integer.MAX_VALUE)); + + assertBefore(idx.toKey(1L), idx.toKey(2L)); + assertBefore(idx.toKey(-1L), idx.toKey(2L)); + assertBefore(idx.toKey(Long.MIN_VALUE), idx.toKey(Long.MAX_VALUE)); + + assertBefore(idx.toKey((short) 1), idx.toKey((short) 2)); + assertBefore(idx.toKey((short) -1), idx.toKey((short) 2)); + assertBefore(idx.toKey(Short.MIN_VALUE), idx.toKey(Short.MAX_VALUE)); + + assertBefore(idx.toKey((byte) 1), idx.toKey((byte) 2)); + assertBefore(idx.toKey((byte) -1), idx.toKey((byte) 2)); + assertBefore(idx.toKey(Byte.MIN_VALUE), idx.toKey(Byte.MAX_VALUE)); + + byte prefix = LevelDBTypeInfo.ENTRY_PREFIX; + assertSame(new byte[] { prefix, LevelDBTypeInfo.FALSE }, idx.toKey(false)); + assertSame(new byte[] { prefix, LevelDBTypeInfo.TRUE }, idx.toKey(true)); + } + + @Test + public void testArrayIndices() throws Exception { + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + + assertBefore(idx.toKey(new String[] { "str1" }), idx.toKey(new String[] { "str2" })); + assertBefore(idx.toKey(new String[] { "str1", "str2" }), + idx.toKey(new String[] { "str1", "str3" })); + + assertBefore(idx.toKey(new int[] { 1 }), idx.toKey(new int[] { 2 })); + assertBefore(idx.toKey(new int[] { 1, 2 }), idx.toKey(new int[] { 1, 3 })); + } + + private LevelDBTypeInfo newTypeInfo(Class type) throws Exception { + return new LevelDBTypeInfo(null, type, type.getName().getBytes(UTF_8)); + } + + private void assertBefore(byte[] key1, byte[] key2) { + assertBefore(new String(key1, UTF_8), new String(key2, UTF_8)); + } + + private void assertBefore(String str1, String str2) { + assertTrue(String.format("%s < %s failed", str1, str2), str1.compareTo(str2) < 0); + } + + private void assertSame(byte[] key1, byte[] key2) { + assertEquals(new String(key1, UTF_8), new String(key2, UTF_8)); + } + + public static class NoNaturalIndex { + + public String id; + + } + + public static class NoNaturalIndex2 { + + @KVIndex("id") + public String id; + + } + + public static class DuplicateIndex { + + @KVIndex + public String key; + + @KVIndex("id") + public String id; + + @KVIndex("id") + public String id2; + + } + + public static class EmptyIndexName { + + @KVIndex("") + public String id; + + } + + public static class IllegalIndexName { + + @KVIndex("__invalid") + public String id; + + } + + public static class IllegalIndexMethod { + + @KVIndex("id") + public String id(boolean illegalParam) { + return null; + } + + } + +} diff --git a/common/kvstore/src/test/resources/log4j.properties b/common/kvstore/src/test/resources/log4j.properties new file mode 100644 index 
0000000000000..e8da774f7ca9e --- /dev/null +++ b/common/kvstore/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index b7c985ace69cf..b9aab5a3712c4 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -34,6 +34,7 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files +# - YARN_CONF_DIR, to point Spark towards YARN configuration files when you use YARN # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4ef6656222455..3e10b9eee4e24 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -34,6 +34,156 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ +/** + * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single + * ShuffleMapStage. + * + * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of + * serialized map statuses in order to speed up tasks' requests for map output statuses. + * + * All public methods of this class are thread-safe. + */ +private class ShuffleStatus(numPartitions: Int) { + + // All accesses to the following state must be guarded with `this.synchronized`. + + /** + * MapStatus for each partition. The index of the array is the map partition id. + * Each value in the array is the MapStatus for a partition, or null if the partition + * is not available. Even though in theory a task may run multiple times (due to speculation, + * stage retries, etc.), in practice the likelihood of a map output being available at multiple + * locations is so small that we choose to ignore that case and store only a single location + * for each output. 
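+ * Entries are reset to null when an output is removed, for example when the block manager or
+ * executor that served it is lost.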
+ */ + private[this] val mapStatuses = new Array[MapStatus](numPartitions) + + /** + * The cached result of serializing the map statuses array. This cache is lazily populated when + * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. + */ + private[this] var cachedSerializedMapStatus: Array[Byte] = _ + + /** + * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]] + * serializes the map statuses array it may detect that the result is too large to send in a + * single RPC, in which case it places the serialized array into a broadcast variable and then + * sends a serialized broadcast variable instead. This variable holds a reference to that + * broadcast variable in order to keep it from being garbage collected and to allow for it to be + * explicitly destroyed later on when the ShuffleMapStage is garbage-collected. + */ + private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + + /** + * Counter tracking the number of partitions that have output. This is a performance optimization + * to avoid having to count the number of non-null entries in the `mapStatuses` array and should + * be equivalent to`mapStatuses.count(_ ne null)`. + */ + private[this] var _numAvailableOutputs: Int = 0 + + /** + * Register a map output. If there is already a registered location for the map output then it + * will be replaced by the new location. + */ + def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { + if (mapStatuses(mapId) == null) { + _numAvailableOutputs += 1 + invalidateSerializedMapOutputStatusCache() + } + mapStatuses(mapId) = status + } + + /** + * Remove the map output which was served by the specified block manager. + * This is a no-op if there is no registered map output or if the registered output is from a + * different block manager. + */ + def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all map outputs associated with the specified executor. Note that this will also + * remove outputs which are served by an external shuffle server (if one exists), as they are + * still registered with that execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = synchronized { + for (mapId <- 0 until mapStatuses.length) { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId == execId) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + } + + /** + * Number of partitions that have shuffle outputs. + */ + def numAvailableOutputs: Int = synchronized { + _numAvailableOutputs + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + */ + def findMissingPartitions(): Seq[Int] = synchronized { + val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) + assert(missing.size == numPartitions - _numAvailableOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") + missing + } + + /** + * Serializes the mapStatuses array into an efficient compressed format. See the comments on + * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format. 
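+ * The result is either the serialized statuses themselves or, if the serialized form is larger
+ * than `minBroadcastSize`, a serialized broadcast variable that points at them.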
+ * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the map statuses then serialization will only be performed in a single thread and all + * other threads will block until the cache is populated. + */ + def serializedMapStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int): Array[Byte] = synchronized { + if (cachedSerializedMapStatus eq null) { + val serResult = MapOutputTracker.serializeMapStatuses( + mapStatuses, broadcastManager, isLocal, minBroadcastSize) + cachedSerializedMapStatus = serResult._1 + cachedSerializedBroadcast = serResult._2 + } + cachedSerializedMapStatus + } + + // Used in testing. + def hasCachedSerializedBroadcast: Boolean = synchronized { + cachedSerializedBroadcast != null + } + + /** + * Helper function which provides thread-safe access to the mapStatuses array. + * The function should NOT mutate the array. + */ + def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized { + f(mapStatuses) + } + + /** + * Clears the cached serialized map output statuses. + */ + def invalidateSerializedMapOutputStatusCache(): Unit = synchronized { + if (cachedSerializedBroadcast != null) { + cachedSerializedBroadcast.destroy() + cachedSerializedBroadcast = null + } + cachedSerializedMapStatus = null + } +} + private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage @@ -62,37 +212,26 @@ private[spark] class MapOutputTrackerMasterEndpoint( } /** - * Class that keeps track of the location of the map output of - * a stage. This is abstract because different versions of MapOutputTracker - * (driver and executor) use different HashMap to store its metadata. - */ + * Class that keeps track of the location of the map output of a stage. This is abstract because the + * driver and executor have different versions of the MapOutputTracker. In principle the driver- + * and executor-side classes don't need to share a common base class; the current shared base class + * is maintained primarily for backwards-compatibility in order to avoid having to update existing + * test code. +*/ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ var trackerEndpoint: RpcEndpointRef = _ /** - * This HashMap has different behavior for the driver and the executors. - * - * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks. - * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the - * driver's corresponding HashMap. - * - * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a - * thread-safe map. - */ - protected val mapStatuses: Map[Int, Array[MapStatus]] - - /** - * Incremented every time a fetch fails so that client nodes know to clear - * their cache of map output locations if this happens. + * The driver-side counter is incremented every time that a map output is lost. This value is sent + * to executors as part of tasks, where executors compare the new epoch number to the highest + * epoch number that they received in the past. If the new epoch number is higher then executors + * will clear their local caches of map output statuses and will re-fetch (possibly updated) + * statuses from the driver. 
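The epoch comment above describes how executors decide to throw away their cached map output locations. A toy sketch of that check, assuming an opaque string in place of the serialized statuses (all names here are hypothetical, not Spark API):

    import scala.collection.mutable

    class EpochedCache {
      private val epochLock = new AnyRef
      private var epoch = 0L
      private val cache = mutable.Map[Int, String]()

      def updateEpoch(newEpoch: Long): Unit = epochLock.synchronized {
        if (newEpoch > epoch) {
          epoch = newEpoch
          cache.clear()      // a higher epoch means some map output was lost: drop stale statuses
        }
      }

      def put(shuffleId: Int, statuses: String): Unit = epochLock.synchronized {
        cache(shuffleId) = statuses
      }

      def get(shuffleId: Int): Option[String] = epochLock.synchronized { cache.get(shuffleId) }
    }

    val c = new EpochedCache
    c.put(10, "locations-v1")
    c.updateEpoch(1L)          // driver reported a lost output; cached locations are discarded
    assert(c.get(10).isEmpty)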
*/ protected var epoch: Long = 0 protected val epochLock = new AnyRef - /** Remembers which map output locations are currently being fetched on an executor. */ - private val fetching = new HashSet[Int] - /** * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. @@ -116,14 +255,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** - * Called from executors to get the server URIs and output sizes for each shuffle block that - * needs to be read from a given reduce task. - * - * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block id, shuffle block size) tuples - * describing the shuffle blocks that are stored at that block manager. - */ + // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) @@ -139,135 +271,31 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { - logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) - } - } + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] /** - * Return statistics about all of the outputs for a given shuffle. + * Deletes map output status information for the specified shuffle stage. */ - def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { - val statuses = getStatuses(dep.shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { - for (i <- 0 until totalSizes.length) { - totalSizes(i) += s.getSizeForBlock(i) - } - } - new MapOutputStatistics(dep.shuffleId, totalSizes) - } - } - - /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize - * on this array when reading it, because on the driver, we may be changing it in place. - * - * (It would be nice to remove this restriction in the future.) - */ - private def getStatuses(shuffleId: Int): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - val startTime = System.currentTimeMillis - var fetchedStatuses: Array[MapStatus] = null - fetching.synchronized { - // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { - try { - fetching.wait() - } catch { - case e: InterruptedException => - } - } - - // Either while we waited the fetch happened successfully, or - // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - // We have to do the fetch, get others to wait for us. 
- fetching += shuffleId - } - } + def unregisterShuffle(shuffleId: Int): Unit - if (fetchedStatuses == null) { - // We won the race to fetch the statuses; do so - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - // This try-finally prevents hangs due to timeouts: - try { - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() - } - } - } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + - s"${System.currentTimeMillis - startTime} ms") - - if (fetchedStatuses != null) { - return fetchedStatuses - } else { - logError("Missing all output locations for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) - } - } else { - return statuses - } - } - - /** Called to get current epoch number. */ - def getEpoch: Long = { - epochLock.synchronized { - return epoch - } - } - - /** - * Called from executors to update the epoch number, potentially clearing old outputs - * because of a fetch failure. Each executor task calls this with the latest epoch - * number on the driver at the time it was created. - */ - def updateEpoch(newEpoch: Long) { - epochLock.synchronized { - if (newEpoch > epoch) { - logInfo("Updating epoch to " + newEpoch + " and clearing cache") - epoch = newEpoch - mapStatuses.clear() - } - } - } - - /** Unregister shuffle data. */ - def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - } - - /** Stop the tracker. */ - def stop() { } + def stop() {} } /** - * MapOutputTracker for the driver. + * Driver-side class that keeps track of the location of the map output of a stage. + * + * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics + * for performing locality-aware reduce task scheduling. + * + * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine + * which tasks need to be run. */ -private[spark] class MapOutputTrackerMaster(conf: SparkConf, - broadcastManager: BroadcastManager, isLocal: Boolean) +private[spark] class MapOutputTrackerMaster( + conf: SparkConf, + broadcastManager: BroadcastManager, + isLocal: Boolean) extends MapOutputTracker(conf) { - /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ - private var cacheEpoch = epoch - // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt @@ -287,22 +315,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 - // HashMaps for storing mapStatuses and cached serialized statuses in the driver. + // HashMap for storing shuffleStatuses in the driver. // Statuses are dropped only by explicit de-registering. 
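The broadcast threshold is read with SparkConf.getSizeAsBytes, so it accepts suffixed values such as "512k". A small sketch of reading the same setting; the key and default are taken from the hunk above, and the 524288-byte result assumes the setting has not been overridden:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    val minSizeForBroadcast =
      conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt
    println(minSizeForBroadcast)   // 524288 unless the configuration overrides it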
- protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala - private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala + private val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // Kept in sync with cachedSerializedStatuses explicitly - // This is required so that the Broadcast variable remains in scope until we remove - // the shuffleId explicitly or implicitly. - private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]() - - // This is to prevent multiple serializations of the same shuffle - which happens when - // there is a request storm when shuffle start. - private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]() - // requests for map output statuses private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] @@ -348,8 +366,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, val hostPort = context.senderAddress.hostPort logDebug("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) - context.reply(mapOutputStatuses) + val shuffleStatus = shuffleStatuses.get(shuffleId).head + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -363,59 +382,77 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, /** A poison endpoint that indicates MessageLoop should exit its message loop. */ private val PoisonPill = new GetMapOutputMessage(-99, null) - // Exposed for testing - private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size + // Used only in unit tests. 
+ private[spark] def getNumCachedSerializedBroadcast: Int = { + shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) + } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - // add in advance - shuffleIdLocks.putIfAbsent(shuffleId, new Object()) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - val array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } - } - - /** Register multiple map output information for the given shuffle */ - def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, statuses.clone()) - if (changeEpoch) { - incrementEpoch() - } + shuffleStatuses(shuffleId).addMapOutput(mapId, status) } /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - val arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - val array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMapOutput(mapId, bmAddress) + incrementEpoch() + case None => + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } } /** Unregister shuffle data */ - override def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - cachedSerializedStatuses.remove(shuffleId) - cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v)) - shuffleIdLocks.remove(shuffleId) + def unregisterShuffle(shuffleId: Int) { + shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => + shuffleStatus.invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) } + incrementEpoch() } /** Check if the given shuffle is being tracked */ - def containsShuffle(shuffleId: Int): Boolean = { - cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) + def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) + + def getNumAvailableOutputs(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0) + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None + * if the MapOutputTrackerMaster doesn't know about this shuffle. + */ + def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = { + shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) + } + + /** + * Return statistics about all of the outputs for a given shuffle. 
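registerShuffle relies on the fact that put on the asScala-wrapped ConcurrentHashMap returns the previous value as an Option, which makes the double-registration check atomic with the insert. A minimal sketch of that idiom (a string stands in for ShuffleStatus):

    import java.util.concurrent.ConcurrentHashMap
    import scala.collection.JavaConverters._

    val shuffleStatuses = new ConcurrentHashMap[Int, String]().asScala

    def registerShuffle(shuffleId: Int, status: String): Unit = {
      if (shuffleStatuses.put(shuffleId, status).isDefined) {
        throw new IllegalArgumentException(s"Shuffle ID $shuffleId registered twice")
      }
    }

    registerShuffle(0, "status-0")
    // registerShuffle(0, "status-0-again")   // would throw IllegalArgumentException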
+ */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } } /** @@ -459,9 +496,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, fractionThreshold: Double) : Option[Array[BlockManagerId]] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses != null) { - statuses.synchronized { + val shuffleStatus = shuffleStatuses.get(shuffleId).orNull + if (shuffleStatus != null) { + shuffleStatus.withMapStatuses { statuses => if (statuses.nonEmpty) { // HashMap to add up sizes of all blocks at the same location val locs = new HashMap[BlockManagerId, Long] @@ -502,77 +539,24 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } } - private def removeBroadcast(bcast: Broadcast[_]): Unit = { - if (null != bcast) { - broadcastManager.unbroadcast(bcast.id, - removeFromDriver = true, blocking = false) + /** Called to get current epoch number. */ + def getEpoch: Long = { + epochLock.synchronized { + return epoch } } - private def clearCachedBroadcast(): Unit = { - for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2) - cachedSerializedBroadcast.clear() - } - - def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { - var statuses: Array[MapStatus] = null - var retBytes: Array[Byte] = null - var epochGotten: Long = -1 - - // Check to see if we have a cached version, returns true if it does - // and has side effect of setting retBytes. If not returns false - // with side effect of setting statuses - def checkCachedStatuses(): Boolean = { - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - clearCachedBroadcast() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - retBytes = bytes - true - case None => - logDebug("cached status not found for : " + shuffleId) - statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus]) - epochGotten = epoch - false - } - } - } - - if (checkCachedStatuses()) return retBytes - var shuffleIdLock = shuffleIdLocks.get(shuffleId) - if (null == shuffleIdLock) { - val newLock = new Object() - // in general, this condition should be false - but good to be paranoid - val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) - shuffleIdLock = if (null != prevLock) prevLock else newLock - } - // synchronize so we only serialize/broadcast it once since multiple threads call - // in parallel - shuffleIdLock.synchronized { - // double check to make sure someone else didn't serialize and cache the same - // mapstatus while we were waiting on the synchronize - if (checkCachedStatuses()) return retBytes - - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "statuses"; let's serialize and return that - val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager, - isLocal, minSizeForBroadcast) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes - if (null != bcast) 
cachedSerializedBroadcast(shuffleId) = bcast - } else { - logInfo("Epoch changed, not caching!") - removeBroadcast(bcast) + // This method is only called in local-mode. + def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + shuffleStatuses.get(shuffleId) match { + case Some (shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } - } - bytes + case None => + Seq.empty } } @@ -580,21 +564,121 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, mapOutputRequests.offer(PoisonPill) threadpool.shutdown() sendTracker(StopMapOutputTracker) - mapStatuses.clear() trackerEndpoint = null - cachedSerializedStatuses.clear() - clearCachedBroadcast() - shuffleIdLocks.clear() + shuffleStatuses.clear() } } /** - * MapOutputTracker for the executors, which fetches map output information from the driver's - * MapOutputTrackerMaster. + * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster. + * Note that this is not used in local-mode; instead, local-mode Executors access the + * MapOutputTrackerMaster directly (which is possible because the master and worker share a comon + * superclass). */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { - protected val mapStatuses: Map[Int, Array[MapStatus]] = + + val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + + /** Remembers which map output locations are currently being fetched on an executor. */ + private val fetching = new HashSet[Int] + + override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + try { + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } catch { + case e: MetadataFetchFailedException => + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + mapStatuses.clear() + throw e + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis + var fetchedStatuses: Array[MapStatus] = null + fetching.synchronized { + // Someone else is fetching it; wait for them to be done + while (fetching.contains(shuffleId)) { + try { + fetching.wait() + } catch { + case e: InterruptedException => + } + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. 
+ fetching += shuffleId + } + } + + if (fetchedStatuses == null) { + // We won the race to fetch the statuses; do so + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + // This try-finally prevents hangs due to timeouts: + try { + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${System.currentTimeMillis - startTime} ms") + + if (fetchedStatuses != null) { + fetchedStatuses + } else { + logError("Missing all output locations for shuffle " + shuffleId) + throw new MetadataFetchFailedException( + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) + } + } else { + statuses + } + } + + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int): Unit = { + mapStatuses.remove(shuffleId) + } + + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each executor task calls this with the latest epoch + * number on the driver at the time it was created. + */ + def updateEpoch(newEpoch: Long): Unit = { + epochLock.synchronized { + if (newEpoch > epoch) { + logInfo("Updating epoch to " + newEpoch + " and clearing cache") + epoch = newEpoch + mapStatuses.clear() + } + } + } } private[spark] object MapOutputTracker extends Logging { @@ -683,7 +767,7 @@ private[spark] object MapOutputTracker extends Logging { * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - private def convertMapStatuses( + def convertMapStatuses( shuffleId: Int, startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1a2443f7ee78d..b2a26c51d4de1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -195,6 +195,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _conf: SparkConf = _ private var _eventLogDir: Option[URI] = None private var _eventLogCodec: Option[String] = None + private var _listenerBus: LiveListenerBus = _ private var _env: SparkEnv = _ private var _jobProgressListener: JobProgressListener = _ private var _statusTracker: SparkStatusTracker = _ @@ -247,7 +248,7 @@ class SparkContext(config: SparkConf) extends Logging { def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events - private[spark] val listenerBus = new LiveListenerBus(this) + private[spark] def listenerBus: LiveListenerBus = _listenerBus // This function allows components created by SparkEnv to be mocked in unit tests: private[spark] def createSparkEnv( @@ -423,6 +424,8 @@ class SparkContext(config: SparkConf) extends Logging { if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") + _listenerBus = new LiveListenerBus(_conf) + // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. 
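The worker-side getStatuses above uses a `fetching` set plus wait/notifyAll so that only one thread per shuffle id asks the driver for statuses while other threads block until the cache is populated. A self-contained sketch of that fetch-deduplication pattern, simplified (no failure handling, a string stands in for the statuses array, names are illustrative):

    import java.util.concurrent.ConcurrentHashMap
    import scala.collection.JavaConverters._
    import scala.collection.mutable

    object FetchOnce {
      private val fetching = new mutable.HashSet[Int]
      private val cache = new ConcurrentHashMap[Int, String]().asScala

      def getOrFetch(id: Int)(fetch: Int => String): String = {
        cache.getOrElse(id, {
          var iAmTheFetcher = false
          fetching.synchronized {
            while (fetching.contains(id)) fetching.wait()   // someone else is fetching; wait
            if (!cache.contains(id)) { fetching += id; iAmTheFetcher = true }
          }
          if (iAmTheFetcher) {
            try {
              val value = fetch(id)
              cache(id) = value
              value
            } finally {
              fetching.synchronized { fetching -= id; fetching.notifyAll() }
            }
          } else {
            cache(id)   // the other thread completed the fetch while we waited
          }
        })
      }
    }

    println(FetchOnce.getOrFetch(10)(id => s"statuses-for-$id"))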
_jobProgressListener = new JobProgressListener(_conf) @@ -2388,7 +2391,7 @@ class SparkContext(config: SparkConf) extends Logging { } } - listenerBus.start() + listenerBus.start(this, _env.metricsSystem) _listenerBusStarted = true } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 5100a17006e24..3d9a14c51618b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -187,6 +187,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull + repositories = Option(repositories) + .orElse(sparkProperties.get("spark.jars.repositories")).orNull deployMode = Option(deployMode) .orElse(sparkProperties.get("spark.submit.deployMode")) .orElse(env.get("DEPLOY_MODE")) @@ -556,8 +558,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --verbose, -v Print additional debug output. | --version, Print the version of current Spark. | - | Spark standalone with cluster deploy mode only: - | --driver-cores NUM Cores for driver (Default: 1). + | Cluster deploy mode only: + | --driver-cores NUM Number of cores used by the driver, only in cluster mode + | (Default: 1). | | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. @@ -572,8 +575,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | or all available cores on the worker in standalone mode) | | YARN-only: - | --driver-cores NUM Number of cores used by the driver, only in cluster mode - | (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). | If dynamic allocation is enabled, the initial number of diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5b396687dd11a..19e7eb086f413 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -322,8 +322,14 @@ private[spark] class Executor( throw new TaskKilledException(killReason.get) } - logDebug("Task " + taskId + "'s epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) + // The purpose of updating the epoch here is to invalidate executor map output status cache + // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be + // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so + // we don't need to make any special calls here. + if (!isLocal) { + logDebug("Task " + taskId + "'s epoch is " + task.epoch) + env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch) + } // Run the actual task and measure its runtime. 
taskStart = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4ad04b04c312d..7827e6760f355 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -158,6 +158,12 @@ package object config { .checkValue(_ > 0, "The capacity of listener bus event queue must not be negative") .createWithDefault(10000) + private[spark] val LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED = + ConfigBuilder("spark.scheduler.listenerbus.metrics.maxListenerClassesTimed") + .internal() + .intConf + .createWithDefault(128) + // This property sets the root namespace for metrics reporting private[spark] val METRICS_NAMESPACE = ConfigBuilder("spark.metrics.namespace") .stringConf diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 63a87e7f09d85..2985c90119468 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1118,9 +1118,9 @@ abstract class RDD[T: ClassTag]( /** * Aggregates the elements of this RDD in a multi-level tree pattern. + * This method is semantically identical to [[org.apache.spark.rdd.RDD#aggregate]]. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#aggregate]] */ def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, @@ -1134,7 +1134,7 @@ abstract class RDD[T: ClassTag]( val cleanCombOp = context.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var partiallyAggregated: RDD[U] = mapPartitions(it => Iterator(aggregatePartition(it))) var numPartitions = partiallyAggregated.partitions.length val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce @@ -1146,9 +1146,10 @@ abstract class RDD[T: ClassTag]( val curNumPartitions = numPartitions partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + }.foldByKey(zeroValue, new HashPartitioner(curNumPartitions))(cleanCombOp).values } - partiallyAggregated.reduce(cleanCombOp) + val copiedZeroValue = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + partiallyAggregated.fold(copiedZeroValue)(cleanCombOp) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ab2255f8a6654..932e6c138e1c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -328,25 +328,14 @@ class DAGScheduler( val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() - val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep) + val stage = new ShuffleMapStage( + id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker) stageIdToStage(id) = stage shuffleIdToMapStage(shuffleDep.shuffleId) = stage updateJobIdStageIdMaps(jobId, stage) - if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) 
{ - // A previously run stage generated partitions for this shuffle, so for each output - // that's still available, copy information about that output location to the new stage - // (so we don't unnecessarily re-compute that data). - val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) - val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - (0 until locs.length).foreach { i => - if (locs(i) ne null) { - // locs(i) will be null if missing - stage.addOutputLoc(i, locs(i)) - } - } - } else { + if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") @@ -1217,7 +1206,8 @@ class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. - shuffleStage.addOutputLoc(smt.partitionId, status) + mapOutputTracker.registerMapOutput( + shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) // Remove the task's partition from pending partitions. This may have already been // done above, but will not have been done yet in cases where the task attempt was // from an earlier attempt of the stage (i.e., not the attempt that's currently @@ -1234,16 +1224,14 @@ class DAGScheduler( logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) + // This call to increment the epoch may not be strictly necessary, but it is retained + // for now in order to minimize the changes in behavior from an earlier version of the + // code. This existing behavior of always incrementing the epoch following any + // successful shuffle map stage completion may have benefits by causing unneeded + // cached map outputs to be cleaned up earlier on executors. In the future we can + // consider removing this call, but this will require some extra investigation. + // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. 
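The RDD.scala hunk above changes treeAggregate to fold with a cloned zero value (and foldByKey between levels), so combine functions that mutate their first argument behave the same as with aggregate. A small sketch in the spirit of the regression test added later in this diff, assuming a throwaway local SparkContext:

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("treeAggregateSketch"))
    try {
      val rdd = sc.parallelize(-1000 until 1000, 10).map(x => Array(x))
      // op mutates its first argument, which is only safe because the zero value is not shared.
      def op(c: Array[Int], x: Array[Int]): Array[Int] = { c(0) += x(0); c }
      val sum = rdd.treeAggregate(Array(0))(op, op, depth = 3)
      assert(sum(0) == -1000)
    } finally {
      sc.stop()
    }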
+ mapOutputTracker.incrementEpoch() clearCacheLocs() @@ -1343,7 +1331,6 @@ class DAGScheduler( } // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) } @@ -1393,17 +1380,7 @@ class DAGScheduler( if (filesLost || !env.blockManager.externalShuffleServiceEnabled) { logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleIdToMapStage) { - stage.removeOutputsOnExecutor(execId) - mapOutputTracker.registerMapOutputs( - shuffleId, - stage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) - } - if (shuffleIdToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() - } + mapOutputTracker.removeOutputsOnExecutor(execId) clearCacheLocs() } } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 801dfaa62306a..f0887e090b956 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -20,10 +20,16 @@ package org.apache.spark.scheduler import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import scala.collection.mutable import scala.util.DynamicVariable -import org.apache.spark.SparkContext +import com.codahale.metrics.{Counter, Gauge, MetricRegistry, Timer} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.util.Utils /** @@ -33,15 +39,20 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when `stop()` is called, and it will drop further events after stopping. */ -private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus { +private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { + self => import LiveListenerBus._ + private var sparkContext: SparkContext = _ + // Cap the capacity of the event queue so we get an explicit error (rather than // an OOM exception) if it's perpetually being added to more quickly than it's being drained. 
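LiveListenerBus caps its event queue and posts with the non-blocking offer, counting a drop whenever the queue is full rather than blocking or running out of memory. A tiny standalone sketch of that backpressure behavior:

    import java.util.concurrent.LinkedBlockingQueue
    import java.util.concurrent.atomic.AtomicLong

    val queue = new LinkedBlockingQueue[String](2)   // deliberately small capacity for illustration
    val dropped = new AtomicLong(0)

    def post(event: String): Unit = {
      if (!queue.offer(event)) {   // offer never blocks; it returns false when the queue is full
        dropped.incrementAndGet()
      }
    }

    (1 to 5).foreach(i => post(s"event-$i"))
    assert(queue.size() == 2 && dropped.get() == 3)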
- private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent]( - sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) + private val eventQueue = + new LinkedBlockingQueue[SparkListenerEvent](conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) + + private[spark] val metrics = new LiveListenerBusMetrics(conf, eventQueue) // Indicate if `start()` is called private val started = new AtomicBoolean(false) @@ -67,6 +78,7 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa setDaemon(true) override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { LiveListenerBus.withinListenerThread.withValue(true) { + val timer = metrics.eventProcessingTime while (true) { eventLock.acquire() self.synchronized { @@ -82,7 +94,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa } return } - postToAll(event) + val timerContext = timer.time() + try { + postToAll(event) + } finally { + timerContext.stop() + } } finally { self.synchronized { processingEvent = false @@ -93,6 +110,10 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa } } + override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { + metrics.getTimerForListenerClass(listener.getClass.asSubclass(classOf[SparkListenerInterface])) + } + /** * Start sending events to attached listeners. * @@ -100,9 +121,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa * listens for any additional events asynchronously while the listener bus is still running. * This should only be called once. * + * @param sc Used to stop the SparkContext in case the listener thread dies. */ - def start(): Unit = { + def start(sc: SparkContext, metricsSystem: MetricsSystem): Unit = { if (started.compareAndSet(false, true)) { + sparkContext = sc + metricsSystem.registerSource(metrics) listenerThread.start() } else { throw new IllegalStateException(s"$name already started!") @@ -115,12 +139,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa logError(s"$name has already stopped! Dropping event $event") return } + metrics.numEventsPosted.inc() val eventAdded = eventQueue.offer(event) if (eventAdded) { eventLock.release() } else { onDropEvent(event) - droppedEventsCounter.incrementAndGet() } val droppedEvents = droppedEventsCounter.get @@ -200,6 +224,8 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa * Note: `onDropEvent` can be called in any thread. */ def onDropEvent(event: SparkListenerEvent): Unit = { + metrics.numDroppedEvents.inc() + droppedEventsCounter.incrementAndGet() if (logDroppedEvent.compareAndSet(false, true)) { // Only log the following message once to avoid duplicated annoying logs. logError("Dropping SparkListenerEvent because no remaining room in event queue. " + @@ -217,3 +243,64 @@ private[spark] object LiveListenerBus { val name = "SparkListenerBus" } +private[spark] class LiveListenerBusMetrics( + conf: SparkConf, + queue: LinkedBlockingQueue[_]) + extends Source with Logging { + + override val sourceName: String = "LiveListenerBus" + override val metricRegistry: MetricRegistry = new MetricRegistry + + /** + * The total number of events posted to the LiveListenerBus. This is a count of the total number + * of events which have been produced by the application and sent to the listener bus, NOT a + * count of the number of events which have been processed and delivered to listeners (or dropped + * without being delivered). 
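Event dispatch is now wrapped in a Dropwizard Timer via time()/stop(), which is what feeds the eventProcessingTime metric. A minimal sketch of that pattern against the com.codahale.metrics API; the metric name is reused from the hunk above, while the bus itself is elided:

    import com.codahale.metrics.{MetricRegistry, Timer}

    val registry = new MetricRegistry
    val eventProcessingTime: Timer = registry.timer(MetricRegistry.name("eventProcessingTime"))

    def postToAll(event: String): Unit = {
      val ctx = eventProcessingTime.time()
      try {
        println(s"delivering $event to all listeners")
      } finally {
        ctx.stop()   // records the elapsed time even if a listener throws
      }
    }

    postToAll("SparkListenerJobEnd")
    assert(eventProcessingTime.getCount == 1)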
+ */ + val numEventsPosted: Counter = metricRegistry.counter(MetricRegistry.name("numEventsPosted")) + + /** + * The total number of events that were dropped without being delivered to listeners. + */ + val numDroppedEvents: Counter = metricRegistry.counter(MetricRegistry.name("numEventsDropped")) + + /** + * The amount of time taken to post a single event to all listeners. + */ + val eventProcessingTime: Timer = metricRegistry.timer(MetricRegistry.name("eventProcessingTime")) + + /** + * The number of messages waiting in the queue. + */ + val queueSize: Gauge[Int] = { + metricRegistry.register(MetricRegistry.name("queueSize"), new Gauge[Int]{ + override def getValue: Int = queue.size() + }) + } + + // Guarded by synchronization. + private val perListenerClassTimers = mutable.Map[String, Timer]() + + /** + * Returns a timer tracking the processing time of the given listener class. + * events processed by that listener. This method is thread-safe. + */ + def getTimerForListenerClass(cls: Class[_ <: SparkListenerInterface]): Option[Timer] = { + synchronized { + val className = cls.getName + val maxTimed = conf.get(LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED) + perListenerClassTimers.get(className).orElse { + if (perListenerClassTimers.size == maxTimed) { + logError(s"Not measuring processing time for listener class $className because a " + + s"maximum of $maxTimed listener classes are already timed.") + None + } else { + perListenerClassTimers(className) = + metricRegistry.timer(MetricRegistry.name("listenerProcessingTime", className)) + perListenerClassTimers.get(className) + } + } + } + } +} + diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index db4d9efa2270c..05f650fbf5df9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -19,9 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet -import org.apache.spark.ShuffleDependency +import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage( parents: List[Stage], firstJobId: Int, callSite: CallSite, - val shuffleDep: ShuffleDependency[_, _, _]) + val shuffleDep: ShuffleDependency[_, _, _], + mapOutputTrackerMaster: MapOutputTrackerMaster) extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { private[this] var _mapStageJobs: List[ActiveJob] = Nil - private[this] var _numAvailableOutputs: Int = 0 - /** * Partitions that either haven't yet been computed, or that were computed on an executor * that has since been lost, so should be re-computed. This variable is used by the @@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage( */ val pendingPartitions = new HashSet[Int] - /** - * List of [[MapStatus]] for each partition. The index of the array is the map partition id, - * and each value in the array is the list of possible [[MapStatus]] for a partition - * (a single task might run multiple times). - */ - private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - override def toString: String = "ShuffleMapStage " + id /** @@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage( /** * Number of partitions that have shuffle outputs. 
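getTimerForListenerClass memoizes one Timer per listener class and stops creating new ones once the configured cap is reached. A standalone sketch of that capped-memoization idiom (strings stand in for Timers and the cap value is arbitrary):

    object ListenerTimers {
      import scala.collection.mutable

      private val maxTimed = 2
      private val perClassTimers = mutable.Map[String, String]()   // value stands in for a Timer

      def timerFor(className: String): Option[String] = synchronized {
        perClassTimers.get(className).orElse {
          if (perClassTimers.size >= maxTimed) {
            None                                     // cap reached: this class is not timed
          } else {
            perClassTimers(className) = s"timer-$className"
            perClassTimers.get(className)
          }
        }
      }
    }

    assert(ListenerTimers.timerFor("ListenerA").isDefined)
    assert(ListenerTimers.timerFor("ListenerB").isDefined)
    assert(ListenerTimers.timerFor("ListenerC").isEmpty)   // third class exceeds the cap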
* When this reaches [[numPartitions]], this map stage is ready. - * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. */ - def numAvailableOutputs: Int = _numAvailableOutputs + def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId) /** * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. - * This should be the same as `outputLocs.contains(Nil)`. */ - def isAvailable: Boolean = _numAvailableOutputs == numPartitions + def isAvailable: Boolean = numAvailableOutputs == numPartitions /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ override def findMissingPartitions(): Seq[Int] = { - val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) - assert(missing.size == numPartitions - _numAvailableOutputs, - s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") - missing - } - - def addOutputLoc(partition: Int, status: MapStatus): Unit = { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - _numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - _numAvailableOutputs -= 1 - } - } - - /** - * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned - * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, - * that position is filled with null. - */ - def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = { - outputLocs.map(_.headOption.orNull) - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. - */ - def removeOutputsOnExecutor(execId: String): Unit = { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - _numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, _numAvailableOutputs, numPartitions, isAvailable)) - } + mapOutputTrackerMaster + .findMissingPartitions(shuffleDep.shuffleId) + .getOrElse(0 until numPartitions) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1b6bc9139f9c9..629cfc7c7a8ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( var backend: SchedulerBackend = null - val mapOutputTracker = SparkEnv.get.mapOutputTracker + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO @@ -240,8 +240,8 @@ private[spark] class TaskSchedulerImpl private[scheduler]( // 2. 
The task set manager has been created but no tasks has been scheduled. In this case, // simply abort the stage. tsm.runningTasksSet.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread, reason = "stage cancelled") + taskIdToExecutorId.get(tid).foreach(execId => + backend.killTask(tid, execId, interruptThread, reason = "Stage cancelled")) } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index fa5ad4e8d81e1..76a56298aaebc 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -23,6 +23,8 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal +import com.codahale.metrics.Timer + import org.apache.spark.internal.Logging /** @@ -30,14 +32,22 @@ import org.apache.spark.internal.Logging */ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { + private[this] val listenersPlusTimers = new CopyOnWriteArrayList[(L, Option[Timer])] + // Marked `private[spark]` for access in tests. - private[spark] val listeners = new CopyOnWriteArrayList[L] + private[spark] def listeners = listenersPlusTimers.asScala.map(_._1).asJava + + /** + * Returns a CodaHale metrics Timer for measuring the listener's event processing time. + * This method is intended to be overridden by subclasses. + */ + protected def getTimer(listener: L): Option[Timer] = None /** * Add a listener to listen events. This method is thread-safe and can be called in any thread. */ final def addListener(listener: L): Unit = { - listeners.add(listener) + listenersPlusTimers.add((listener, getTimer(listener))) } /** @@ -45,7 +55,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * in any thread. */ final def removeListener(listener: L): Unit = { - listeners.remove(listener) + listenersPlusTimers.asScala.find(_._1 eq listener).foreach { listenerAndTimer => + listenersPlusTimers.remove(listenerAndTimer) + } } /** @@ -56,14 +68,25 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here we use // Java Iterator directly. 
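ListenerBus now stores (listener, Option[Timer]) pairs in a CopyOnWriteArrayList and removes a listener by reference identity (`eq`) rather than tuple equality, since the timer half of the pair is not known to callers. A small sketch of that removal idiom:

    import java.util.concurrent.CopyOnWriteArrayList
    import scala.collection.JavaConverters._

    val listenersPlusTimers = new CopyOnWriteArrayList[(AnyRef, Option[String])]()

    val l1 = new AnyRef
    val l2 = new AnyRef
    listenersPlusTimers.add((l1, Some("timer-1")))
    listenersPlusTimers.add((l2, None))

    // Match on the identity of the listener, then remove the whole pair.
    listenersPlusTimers.asScala.find(_._1 eq l1).foreach(p => listenersPlusTimers.remove(p))
    assert(listenersPlusTimers.size() == 1)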
- val iter = listeners.iterator + val iter = listenersPlusTimers.iterator while (iter.hasNext) { - val listener = iter.next() + val listenerAndMaybeTimer = iter.next() + val listener = listenerAndMaybeTimer._1 + val maybeTimer = listenerAndMaybeTimer._2 + val maybeTimerContext = if (maybeTimer.isDefined) { + maybeTimer.get.time() + } else { + null + } try { doPostEvent(listener, event) } catch { case NonFatal(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) + } finally { + if (maybeTimerContext != null) { + maybeTimerContext.stop() + } } } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 71bedda5ac894..bc3d23e3fbb29 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -23,6 +23,7 @@ import org.mockito.Matchers.any import org.mockito.Mockito._ import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.LocalSparkContext._ import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException @@ -138,21 +139,21 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) + // This is expected to fail because no outputs have been registered for the shuffle. intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) + val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - masterTracker.incrementEpoch() + assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } @@ -245,8 +246,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize // needs TorrentBroadcast so need a SparkContext - val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf) - try { + withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc => val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] val rpcEnv = sc.env.rpcEnv val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) @@ -271,9 +271,6 @@ class MapOutputTrackerSuite extends SparkFunSuite { assert(1 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.unregisterShuffle(20) assert(0 == masterTracker.getNumCachedSerializedBroadcast) - - } finally { - LocalSparkContext.stop(sc) } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 
622f7985ba444..3931d53b4ae0a 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -359,6 +359,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val shuffleMapRdd = new MyRDD(sc, 1, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) + mapTrackerMaster.registerShuffle(0, 1) // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, @@ -393,7 +394,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => - mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) + mapTrackerMaster.registerMapOutput(0, 0, mapStatus) } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 6e9721c45931a..de719990cf47a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -477,6 +477,26 @@ class SparkSubmitSuite } } + test("includes jars passed through spark.jars.packages and spark.jars.repositories") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val main = MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") + // Test using "spark.jars.packages" and "spark.jars.repositories" configurations. + IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.jars.packages=my.great.lib:mylib:0.1,my.great.dep:mylib:0.1", + "--conf", s"spark.jars.repositories=$repo", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + unusedJar.toString, + "my.great.lib.MyLib", "my.great.dep.MyLib") + runSparkSubmit(args) + } + } + // TODO(SPARK-9603): Building a package is flaky on Jenkins Maven builds. 
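MapOutputTrackerSuite now wraps the SparkContext in LocalSparkContext.withSpark so the context is stopped even when an assertion fails. A sketch of that loan pattern; the helper shown here is a simplified stand-in, not the actual LocalSparkContext implementation:

    import org.apache.spark.{SparkConf, SparkContext}

    def withSpark[T](sc: SparkContext)(body: SparkContext => T): T = {
      try body(sc) finally sc.stop()   // the context is stopped even if the body throws
    }

    val count = withSpark(new SparkContext(new SparkConf().setMaster("local").setAppName("loan"))) {
      sc => sc.parallelize(1 to 3).count()
    }
    assert(count == 3)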
// See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log ignore("correctly builds R packages included in a jar with --packages") { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 8d06f5468f4f1..386c0060f9c41 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -192,6 +192,23 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(ser.serialize(union.partitions.head).limit() < 2000) } + test("fold") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def op: (Int, Int) => Int = (c: Int, x: Int) => c + x + val sum = rdd.fold(0)(op) + assert(sum === -1000) + } + + test("fold with op modifying first arg") { + val rdd = sc.makeRDD(-1000 until 1000, 10).map(x => Array(x)) + def op: (Array[Int], Array[Int]) => Array[Int] = { (c: Array[Int], x: Array[Int]) => + c(0) += x(0) + c + } + val sum = rdd.fold(Array(0))(op) + assert(sum(0) === -1000) + } + test("aggregate") { val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] @@ -218,7 +235,19 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { def combOp: (Long, Long) => Long = (c1: Long, c2: Long) => c1 + c2 for (depth <- 1 until 10) { val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) - assert(sum === -1000L) + assert(sum === -1000) + } + } + + test("treeAggregate with ops modifying first args") { + val rdd = sc.makeRDD(-1000 until 1000, 10).map(x => Array(x)) + def op: (Array[Int], Array[Int]) => Array[Int] = { (c: Array[Int], x: Array[Int]) => + c(0) += x(0) + c + } + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(Array(0))(op, op, depth) + assert(sum(0) === -1000) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 31d9dd3de8acc..59d8c14d74e30 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -633,7 +633,12 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("port conflict") { val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port) - assert(anotherEnv.address.port != env.address.port) + try { + assert(anotherEnv.address.port != env.address.port) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } } private def testSend(conf: SparkConf): Unit = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 2b18ebee79a2b..571c6bbb4585d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M sc = new SparkContext(conf) val scheduler = mock[TaskSchedulerImpl] when(scheduler.sc).thenReturn(sc) - when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker) + when(scheduler.mapOutputTracker).thenReturn( + SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]) scheduler } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 4c3d0b102152c..4cae6c61118a8 
100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -25,12 +25,14 @@ import scala.io.Source import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.io._ +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{JsonProtocol, Utils} /** @@ -155,17 +157,18 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") val eventLogger = new EventLoggingListener(logName, None, testDirPath.toUri(), conf) - val listenerBus = new LiveListenerBus(sc) + val listenerBus = new LiveListenerBus(conf) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite eventLogger.start() - listenerBus.start() + listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem])) listenerBus.addListener(eventLogger) listenerBus.postToAll(applicationStart) listenerBus.postToAll(applicationEnd) + listenerBus.stop() eventLogger.stop() // Verify file contains exactly the two events logged diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 80c7e0bfee6ef..f3d0bc19675fc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -22,10 +22,13 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.JavaConverters._ +import org.mockito.Mockito import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config.LISTENER_BUS_EVENT_QUEUE_CAPACITY +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers @@ -36,14 +39,17 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L + private val mockSparkContext: SparkContext = Mockito.mock(classOf[SparkContext]) + private val mockMetricsSystem: MetricsSystem = Mockito.mock(classOf[MetricsSystem]) + test("don't call sc.stop in listener") { sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) val listener = new SparkContextStoppingListener(sc) - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(sc.conf) bus.addListener(listener) // Starting listener bus should flush all buffered events - bus.start() + bus.start(sc, sc.env.metricsSystem) bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) @@ -52,35 +58,54 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("basic creation and shutdown of LiveListenerBus") { - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) + val conf = new SparkConf() val counter = new 
BasicJobCounter - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(conf) bus.addListener(counter) - // Listener bus hasn't started yet, so posting events should not increment counter + // Metrics are initially empty. + assert(bus.metrics.numEventsPosted.getCount === 0) + assert(bus.metrics.numDroppedEvents.getCount === 0) + assert(bus.metrics.queueSize.getValue === 0) + assert(bus.metrics.eventProcessingTime.getCount === 0) + + // Post five events: (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + + // Five messages should be marked as received and queued, but no messages should be posted to + // listeners yet because the listener bus hasn't been started. + assert(bus.metrics.numEventsPosted.getCount === 5) + assert(bus.metrics.queueSize.getValue === 5) assert(counter.count === 0) // Starting listener bus should flush all buffered events - bus.start() + bus.start(mockSparkContext, mockMetricsSystem) + Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) + assert(bus.metrics.queueSize.getValue === 0) + assert(bus.metrics.eventProcessingTime.getCount === 5) // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 5) + assert(bus.metrics.numEventsPosted.getCount === 5) + + // Make sure per-listener-class timers were created: + assert(bus.metrics.getTimerForListenerClass( + classOf[BasicJobCounter].asSubclass(classOf[SparkListenerInterface])).get.getCount == 5) // Listener bus must not be started twice intercept[IllegalStateException] { - val bus = new LiveListenerBus(sc) - bus.start() - bus.start() + val bus = new LiveListenerBus(conf) + bus.start(mockSparkContext, mockMetricsSystem) + bus.start(mockSparkContext, mockMetricsSystem) } // ...
or stopped before starting intercept[IllegalStateException] { - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(conf) bus.stop() } } @@ -107,12 +132,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match drained = true } } - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(new SparkConf()) val blockingListener = new BlockingListener bus.addListener(blockingListener) - bus.start() + bus.start(mockSparkContext, mockMetricsSystem) bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() @@ -138,6 +162,44 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(drained) } + test("metrics for dropped listener events") { + val bus = new LiveListenerBus(new SparkConf().set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 1)) + + val listenerStarted = new Semaphore(0) + val listenerWait = new Semaphore(0) + + bus.addListener(new SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + listenerStarted.release() + listenerWait.acquire() + } + }) + + bus.start(mockSparkContext, mockMetricsSystem) + + // Post a message to the listener bus and wait for processing to begin: + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + listenerStarted.acquire() + assert(bus.metrics.queueSize.getValue === 0) + assert(bus.metrics.numDroppedEvents.getCount === 0) + + // If we post an additional message then it should remain in the queue because the listener is + // busy processing the first event: + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + assert(bus.metrics.queueSize.getValue === 1) + assert(bus.metrics.numDroppedEvents.getCount === 0) + + // The queue is now full, so any additional events posted to the listener will be dropped: + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + assert(bus.metrics.queueSize.getValue === 1) + assert(bus.metrics.numDroppedEvents.getCount === 1) + + + // Allow the remaining events to be processed so we can stop the listener bus: + listenerWait.release(2) + bus.stop() + } + test("basic creation of StageInfo") { sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo @@ -354,14 +416,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val badListener = new BadListener val jobCounter1 = new BasicJobCounter val jobCounter2 = new BasicJobCounter - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(new SparkConf()) // Propagate events to bad listener first bus.addListener(badListener) bus.addListener(jobCounter1) bus.addListener(jobCounter2) - bus.start() + bus.start(mockSparkContext, mockMetricsSystem) // Post events to all listeners, and wait until the queue is drained (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 21251f0b93760..cf01f79f49091 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.serializer import
org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.LocalSparkContext +import org.apache.spark.LocalSparkContext._ import org.apache.spark.SparkContext import org.apache.spark.SparkException @@ -32,9 +32,9 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "1m") - val sc = new SparkContext("local", "test", conf) - intercept[SparkException](sc.parallelize(x).collect()) - LocalSparkContext.stop(sc) + withSpark(new SparkContext("local", "test", conf)) { sc => + intercept[SparkException](sc.parallelize(x).collect()) + } } test("kryo with resizable output buffer should succeed on large array") { @@ -42,8 +42,8 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "2m") - val sc = new SparkContext("local", "test", conf) - assert(sc.parallelize(x).collect() === x) - LocalSparkContext.stop(sc) + withSpark(new SparkContext("local", "test", conf)) { sc => + assert(sc.parallelize(x).collect() === x) + } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c100803279eaf..dd61dcd11bcda 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -100,7 +100,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite sc = new SparkContext("local", "test", conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(sc))), conf, true) + new LiveListenerBus(conf))), conf, true) allStores.clear() } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 0d2912ba8c5fb..9d52b488b223e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -125,7 +125,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE when(sc.conf).thenReturn(conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(sc))), conf, true) + new LiveListenerBus(conf))), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index f6c8418ba3ac4..66dda382eb653 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.storage._ /** * Test various functionality in the StorageListener that supports the StorageTab. 
*/ -class StorageTabSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter { +class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { private var bus: LiveListenerBus = _ private var storageStatusListener: StorageStatusListener = _ private var storageListener: StorageListener = _ @@ -43,8 +43,7 @@ class StorageTabSuite extends SparkFunSuite with LocalSparkContext with BeforeAn before { val conf = new SparkConf() - sc = new SparkContext("local", "test", conf) - bus = new LiveListenerBus(sc) + bus = new LiveListenerBus(conf) storageStatusListener = new StorageStatusListener(conf) storageListener = new StorageListener(storageStatusListener) bus.addListener(storageStatusListener) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ab1de3d3dd8ad..9127413ab6c23 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -47,9 +47,9 @@ commons-net-2.2.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar -curator-client-2.6.0.jar -curator-framework-2.6.0.jar -curator-recipes-2.6.0.jar +curator-client-2.7.1.jar +curator-framework-2.7.1.jar +curator-recipes-2.7.1.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar diff --git a/docs/configuration.md b/docs/configuration.md index 0771e36f80b50..f777811a93f62 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -474,10 +474,19 @@ Apart from these, the following properties are also available, and may be useful Path to an Ivy settings file to customize resolution of jars specified using spark.jars.packages instead of the built-in defaults, such as maven central. Additional repositories given by the command-line - option --repositories will also be included. Useful for allowing Spark to resolve artifacts from behind - a firewall e.g. via an in-house artifact server like Artifactory. Details on the settings file format can be + option --repositories or spark.jars.repositories will also be included. + Useful for allowing Spark to resolve artifacts from behind a firewall e.g. via an in-house + artifact server like Artifactory. Details on the settings file format can be found at http://ant.apache.org/ivy/history/latest-milestone/settings.html + + + spark.jars.repositories + + + Comma-separated list of additional remote repositories to search for the maven coordinates + given with --packages or spark.jars.packages. + spark.pyspark.driver.python diff --git a/docs/index.md b/docs/index.md index 960b968454d0e..f7b5863957ce2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -26,15 +26,13 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy locally on one machine --- all you need is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. -Spark runs on Java 8+, Python 2.6+/3.4+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} +Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). -Note that support for Java 7 was removed as of Spark 2.2.0. +Note that support for Java 7, Python 2.6 and old Hadoop versions before 2.6.5 were removed as of Spark 2.2.0. -Note that support for Python 2.6 is deprecated as of Spark 2.0.0, and support for -Scala 2.10 and versions of Hadoop before 2.6 are deprecated as of Spark 2.1.0, and may be -removed in Spark 2.2.0. 
+Note that support for Scala 2.10 is deprecated as of Spark 2.1.0, and may be removed in Spark 2.3.0. # Running the Examples and Shell diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index c1344ad99a7d2..ec130c1db8f5f 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -156,7 +156,7 @@ passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `Mesos If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. The `MesosClusterDispatcher` also supports writing recovery state into Zookeeper. This will allow the `MesosClusterDispatcher` to be able to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. -For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy]. +For more information about these configurations please refer to the configurations [doc](configurations.html#deploy). From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the @@ -382,8 +382,9 @@ See the [configuration page](configuration.html) for information on Spark config (none) Set the Mesos labels to add to each task. Labels are free-form key-value pairs. - Key-value pairs should be separated by a colon, and commas used to list more than one. - Ex. key:value,key2:value2. + Key-value pairs should be separated by a colon, and commas used to + list more than one. If your label includes a colon or comma, you + can escape it with a backslash. Ex. key:value,key2:a\:b. @@ -468,6 +469,15 @@ See the [configuration page](configuration.html) for information on Spark config If unset it will point to Spark's internal web UI. + + spark.mesos.driver.labels + (none) + + Mesos labels to add to the driver. See spark.mesos.task.labels + for formatting information. + + + spark.mesos.driverEnv.[EnvironmentVariableName] (none) diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index 92c296a9e6bd3..386066a85749f 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -91,7 +91,9 @@ The new Kafka consumer API will pre-fetch messages into buffers. Therefore it i In most cases, you should use `LocationStrategies.PreferConsistent` as shown above. This will distribute partitions evenly across available executors. If your executors are on the same hosts as your Kafka brokers, use `PreferBrokers`, which will prefer to schedule partitions on the Kafka leader for that partition. Finally, if you have a significant skew in load among partitions, use `PreferFixed`. This allows you to specify an explicit mapping of partitions to hosts (any unspecified partitions will use a consistent location). -The cache for consumers has a default maximum size of 64. If you expect to be handling more than (64 * number of executors) Kafka partitions, you can change this setting via `spark.streaming.kafka.consumer.cache.maxCapacity` +The cache for consumers has a default maximum size of 64. 
If you expect to be handling more than (64 * number of executors) Kafka partitions, you can change this setting via `spark.streaming.kafka.consumer.cache.maxCapacity`. + +If you would like to disable the caching for Kafka consumers, you can set `spark.streaming.kafka.consumer.cache.enabled` to `false`. Disabling the cache may be needed to work around the problem described in SPARK-19185. This property may be removed in later versions of Spark, once SPARK-19185 is resolved. The cache is keyed by topicpartition and group.id, so use a **separate** `group.id` for each call to `createDirectStream`. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 6a25c9939c264..9b9177d44145f 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1056,7 +1056,7 @@ Some of them are as follows. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). -- `count()` - Cannot return a single count from a streaming Dataset. Instead, use `ds.groupBy.count()` which returns a streaming Dataset containing a running count. +- `count()` - Cannot return a single count from a streaming Dataset. Instead, use `ds.groupBy().count()` which returns a streaming Dataset containing a running count. - `foreach()` - Instead use `ds.writeStream.foreach(...)` (see next section). diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java index 7bb9993b84168..00033b5730a3d 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -40,7 +40,7 @@ public static void main(String[] args) { JavaRDD parsedData = data.map(line -> { String[] parts = line.split(" "); double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) { + for (int i = 1; i < parts.length; i++) { v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); } return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 6d6983c4bd419..9a4a1cf32a480 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -213,8 +213,10 @@ private[spark] class DirectKafkaInputDStream[K, V]( val fo = currentOffsets(tp) OffsetRange(tp.topic, tp.partition, fo, uo) } - val rdd = new KafkaRDD[K, V]( - context.sparkContext, executorKafkaParams, offsetRanges.toArray, getPreferredHosts, true) + val useConsumerCache = context.conf.getBoolean("spark.streaming.kafka.consumer.cache.enabled", + true) + val rdd = new KafkaRDD[K, V](context.sparkContext, executorKafkaParams, offsetRanges.toArray, + getPreferredHosts, useConsumerCache) // Report the record number and metadata of this batch interval to InputInfoTracker.
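The `spark.streaming.kafka.consumer.cache.enabled` flag documented above is read by DirectKafkaInputDStream with a default of `true`. A minimal sketch (not part of this patch) of an application opting out of the consumer cache; the app name, master, and batch interval are illustrative only:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Disable the Kafka consumer cache described above; KafkaRDD is then constructed with
// useConsumerCache = false, so consumers are not reused. Other values are placeholders.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("kafka-consumer-cache-disabled")
  .set("spark.streaming.kafka.consumer.cache.enabled", "false") // default: true
val ssc = new StreamingContext(conf, Seconds(5))
```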
val description = offsetRanges.filter { offsetRange => @@ -316,7 +318,7 @@ private[spark] class DirectKafkaInputDStream[K, V]( b.map(OffsetRange(_)), getPreferredHosts, // during restore, it's possible same partition will be consumed from multiple - // threads, so dont use cache + // threads, so do not use cache. false ) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 4ca062c0b5adf..b6909b3386b71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} @@ -339,25 +340,42 @@ object Word2VecModel extends MLReadable[Word2VecModel] { val wordVectors = instance.wordVectors.getVectors val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) } val dataPath = new Path(path, "data").toString + val bufferSizeInBytes = Utils.byteStringAsBytes( + sc.conf.get("spark.kryoserializer.buffer.max", "64m")) + val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions( + bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize) sparkSession.createDataFrame(dataSeq) - .repartition(calculateNumberOfPartitions) + .repartition(numPartitions) .write .parquet(dataPath) } + } - def calculateNumberOfPartitions(): Int = { - val floatSize = 4 + private[feature] + object Word2VecModelWriter { + /** + * Calculate the number of partitions to use in saving the model. + * [SPARK-11994] - We want to partition the model in partitions smaller than + * spark.kryoserializer.buffer.max + * @param bufferSizeInBytes Set to spark.kryoserializer.buffer.max + * @param numWords Vocab size + * @param vectorSize Vector length for each word + */ + def calculateNumberOfPartitions( + bufferSizeInBytes: Long, + numWords: Int, + vectorSize: Int): Int = { + val floatSize = 4L // Use Long to help avoid overflow val averageWordSize = 15 - // [SPARK-11994] - We want to partition the model in partitions smaller than - // spark.kryoserializer.buffer.max - val bufferSizeInBytes = Utils.byteStringAsBytes( - sc.conf.get("spark.kryoserializer.buffer.max", "64m")) // Calculate the approximate size of the model. // Assuming an average word size of 15 bytes, the formula is: // (floatSize * vectorSize + 15) * numWords - val numWords = instance.wordVectors.wordIndex.size - val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords - ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt + val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords + val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1 + require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " + + s"partitions to save this model, which is too large. 
Try increasing " + + s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.") + numPartitions.toInt } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala new file mode 100644 index 0000000000000..403c28ff732f0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} + +/** + * A parent trait for aggregators used in fitting MLlib models. This parent trait implements + * some of the common code shared between concrete instances of aggregators. Subclasses of this + * aggregator need only implement the `add` method. + * + * @tparam Datum The type of the instances added to the aggregator to update the loss and gradient. + * @tparam Agg Specialization of [[DifferentiableLossAggregator]]. Classes that subclass this + * type need to use this parameter to specify the concrete type of the aggregator. + */ +private[ml] trait DifferentiableLossAggregator[ + Datum, + Agg <: DifferentiableLossAggregator[Datum, Agg]] extends Serializable { + + self: Agg => // enforce classes that extend this to be the same type as `Agg` + + protected var weightSum: Double = 0.0 + protected var lossSum: Double = 0.0 + + /** The dimension of the gradient array. */ + protected val dim: Int + + /** Array of gradient values that are mutated when new instances are added to the aggregator. */ + protected lazy val gradientSumArray: Array[Double] = Array.ofDim[Double](dim) + + /** Add a single data point to this aggregator. */ + def add(instance: Datum): Agg + + /** Merge two aggregators. The `this` object will be modified in place and returned. */ + def merge(other: Agg): Agg = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"${getClass.getSimpleName}. Expecting $dim but got ${other.dim}.") + + if (other.weightSum != 0) { + weightSum += other.weightSum + lossSum += other.lossSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + /** The current weighted averaged gradient. 
*/ + def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but was $weightSum.") + val result = Vectors.dense(gradientSumArray.clone()) + BLAS.scal(1.0 / weightSum, result) + result + } + + /** Weighted count of instances in this aggregator. */ + def weight: Double = weightSum + + /** The current loss value of this aggregator. */ + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but was $weightSum.") + lossSum / weightSum + } + +} + diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala new file mode 100644 index 0000000000000..1994b0e40e520 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in an online fashion. + * + * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * For improving the convergence rate during the optimization process, and also preventing against + * features with very large variances exerting an overly large influence during model training, + * package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce + * the condition number, and then trains the model in scaled space but returns the coefficients in + * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf + * + * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache + * the standardized dataset since it will create a lot of overhead. As a result, we perform the + * scaling implicitly when we compute the objective function. The following is the mathematical + * derivation. + * + * Note that we don't deal with intercept by adding bias here, because the intercept + * can be computed using closed form after the coefficients are converged. + * See this discussion for detail. + * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + * + * When training with intercept enabled, + * The objective function in the scaled space is given by + * + *
+ * $$ + * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, + * $$ + *
+ * + * where $\bar{x_i}$ is the mean of $x_i$, $\hat{x_i}$ is the standard deviation of $x_i$, + * $\bar{y}$ is the mean of the label, and $\hat{y}$ is the standard deviation of the label. + * + * If we fit with the intercept disabled (that is, forced through 0.0), + * we can use the same equation except we set $\bar{y}$ and $\bar{x_i}$ to 0 instead + * of the respective means. + * + * This can be rewritten as + * + *
+ * $$ + * \begin{align} + * L &= 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} + * + \bar{y} / \hat{y}||^2 \\ + * &= 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 + * \end{align} + * $$ + *
+ * + * where $w_i^\prime$ is the effective coefficients defined by $w_i/\hat{x_i}$, offset is + * + *
+ * $$ + * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. + * $$ + *
+ * + * and diff is + * + *
+ * $$ + * \sum_i w_i^\prime x_i - y / \hat{y} + offset + * $$ + *
+ * + * Note that the effective coefficients and offset don't depend on the training dataset, + * so they can be precomputed. + * + * Now, the first derivative of the objective function in the scaled space is + * + *
+ * $$ + * \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i} + * $$ + *
+ * + * However, $(x_i - \bar{x_i})$ will densify the computation, so it's not + * an ideal formula when the training dataset is in sparse format. + * + * This can be addressed by adding the dense $\bar{x_i} / \hat{x_i}$ terms + * in the end by keeping the sum of diff. The first derivative of the total + * objective function from all the samples is + * + * + *
+ * $$ + * \begin{align} + * \frac{\partial L}{\partial w_i} &= + * 1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i} \\ + * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i}) \\ + * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i) + * \end{align} + * $$ + *
+ * + * where $correction_i = - diffSum \bar{x_i} / \hat{x_i}$ + * + * Simple math shows that diffSum is actually zero, so we don't even + * need to add the correction terms in the end. From the definition of diff, + * + *
+ * $$ + * \begin{align} + * diffSum &= \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) + * / \hat{x_i} - (y_j - \bar{y}) / \hat{y}) \\ + * &= N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y}) \\ + * &= 0 + * \end{align} + * $$ + *
+ * + * As a result, the first derivative of the total objective function only depends on + * the training dataset, which can be easily computed in distributed fashion, and is + * sparse format friendly. + * + *
+ * $$ + * \frac{\partial L}{\partial w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + * $$ + *
+ * + * @note The constructor is curried, since the cost function will repeatedly create new versions + * of this class for different coefficient vectors. + * + * @param labelStd The standard deviation value of the label. + * @param labelMean The mean value of the label. + * @param fitIntercept Whether to fit an intercept term. + * @param bcFeaturesStd The broadcast standard deviation values of the features. + * @param bcFeaturesMean The broadcast mean values of the features. + * @param bcCoefficients The broadcast coefficients corresponding to the features. + */ +private[ml] class LeastSquaresAggregator( + labelStd: Double, + labelMean: Double, + fitIntercept: Boolean, + bcFeaturesStd: Broadcast[Array[Double]], + bcFeaturesMean: Broadcast[Array[Double]])(bcCoefficients: Broadcast[Vector]) + extends DifferentiableLossAggregator[Instance, LeastSquaresAggregator] { + require(labelStd > 0.0, s"${this.getClass.getName} requires the label standard " + + s"deviation to be positive.") + + private val numFeatures = bcFeaturesStd.value.length + protected override val dim: Int = numFeatures + // make transient so we do not serialize between aggregation stages + @transient private lazy val featuresStd = bcFeaturesStd.value + @transient private lazy val effectiveCoefAndOffset = { + val coefficientsArray = bcCoefficients.value.toArray.clone() + val featuresMean = bcFeaturesMean.value + var sum = 0.0 + var i = 0 + val len = coefficientsArray.length + while (i < len) { + if (featuresStd(i) != 0.0) { + coefficientsArray(i) /= featuresStd(i) + sum += coefficientsArray(i) * featuresMean(i) + } else { + coefficientsArray(i) = 0.0 + } + i += 1 + } + val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 + (Vectors.dense(coefficientsArray), offset) + } + // do not use tuple assignment above because it will circumvent the @transient tag + @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1 + @transient private lazy val offset = effectiveCoefAndOffset._2 + + /** + * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * + * @param instance The instance of data point to be added. + * @return This LeastSquaresAggregator object. + */ + def add(instance: Instance): LeastSquaresAggregator = { + instance match { case Instance(label, weight, features) => + require(numFeatures == features.size, s"Dimensions mismatch when adding new sample." 
+ + s" Expecting $numFeatures but got ${features.size}.") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") + + if (weight == 0.0) return this + + val diff = BLAS.dot(features, effectiveCoefficientsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + val localFeaturesStd = featuresStd + features.foreachActive { (index, value) => + val fStd = localFeaturesStd(index) + if (fStd != 0.0 && value != 0.0) { + localGradientSumArray(index) += weight * diff * value / fStd + } + } + lossSum += weight * diff * diff / 2.0 + } + weightSum += weight + this + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala new file mode 100644 index 0000000000000..118c0ebfa513e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.loss + +import breeze.optimize.DiffFunction + +/** + * A Breeze diff function which represents a cost function for differentiable regularization + * of parameters. e.g. L2 regularization: 1 / 2 regParam * beta dot beta + * + * @tparam T The type of the coefficients being regularized. + */ +private[ml] trait DifferentiableRegularization[T] extends DiffFunction[T] { + + /** Magnitude of the regularization penalty. */ + def regParam: Double + +} + +/** + * A Breeze diff function for computing the L2 regularized loss and gradient of an array of + * coefficients. + * + * @param regParam The magnitude of the regularization. + * @param shouldApply A function (Int => Boolean) indicating whether a given index should have + * regularization applied to it. + * @param featuresStd Option indicating whether the regularization should be scaled by the standard + * deviation of the features. 
+ */ +private[ml] class L2Regularization( + val regParam: Double, + shouldApply: Int => Boolean, + featuresStd: Option[Array[Double]]) extends DifferentiableRegularization[Array[Double]] { + + override def calculate(coefficients: Array[Double]): (Double, Array[Double]) = { + var sum = 0.0 + val gradient = new Array[Double](coefficients.length) + coefficients.indices.filter(shouldApply).foreach { j => + val coef = coefficients(j) + featuresStd match { + case Some(stds) => + val std = stds(j) + if (std != 0.0) { + val temp = coef / (std * std) + sum += coef * temp + gradient(j) = regParam * temp + } else { + 0.0 + } + case None => + sum += coef * coef + gradient(j) = coef * regParam + } + } + (0.5 * sum * regParam, gradient) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala new file mode 100644 index 0000000000000..3b1618eb0b6fe --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.loss + +import scala.reflect.ClassTag + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.DiffFunction + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator +import org.apache.spark.rdd.RDD + +/** + * This class computes the gradient and loss of a differentiable loss function by mapping a + * [[DifferentiableLossAggregator]] over an [[RDD]] of [[Instance]]s. The loss function is the + * sum of the loss computed on a single instance across all points in the RDD. Therefore, the actual + * analytical form of the loss function is specified by the aggregator, which computes each points + * contribution to the overall loss. + * + * A differentiable regularization component can also be added by providing a + * [[DifferentiableRegularization]] loss function. + * + * @param instances + * @param getAggregator A function which gets a new loss aggregator in every tree aggregate step. + * @param regularization An option representing the regularization loss function to apply to the + * coefficients. + * @param aggregationDepth The aggregation depth of the tree aggregation step. + * @tparam Agg Specialization of [[DifferentiableLossAggregator]], representing the concrete type + * of the aggregator. 
+ */ +private[ml] class RDDLossFunction[ + T: ClassTag, + Agg <: DifferentiableLossAggregator[T, Agg]: ClassTag]( + instances: RDD[T], + getAggregator: (Broadcast[Vector] => Agg), + regularization: Option[DifferentiableRegularization[Array[Double]]], + aggregationDepth: Int = 2) + extends DiffFunction[BDV[Double]] { + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val bcCoefficients = instances.context.broadcast(Vectors.fromBreeze(coefficients)) + val thisAgg = getAggregator(bcCoefficients) + val seqOp = (agg: Agg, x: T) => agg.add(x) + val combOp = (agg1: Agg, agg2: Agg) => agg1.merge(agg2) + val newAgg = instances.treeAggregate(thisAgg)(seqOp, combOp, aggregationDepth) + val gradient = newAgg.gradient + val regLoss = regularization.map { regFun => + val (regLoss, regGradient) = regFun.calculate(coefficients.data) + BLAS.axpy(1.0, Vectors.dense(regGradient), gradient) + regLoss + }.getOrElse(0.0) + bcCoefficients.destroy(blocking = false) + (newAgg.loss + regLoss, gradient.asBreeze.toDenseVector) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index eaad54985229e..db5ac4f14bd3b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -20,19 +20,20 @@ package org.apache.spark.ml.regression import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator +import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -319,8 +320,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam - val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept), - $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam, $(aggregationDepth)) + val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept), + bcFeaturesStd, bcFeaturesMean)(_) + val regularization = if (effectiveL2RegParam != 0.0) { + val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures + Some(new L2Regularization(effectiveL2RegParam, shouldApply, + if ($(standardization)) None else Some(featuresStd))) + } else { + None + } + val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization, + $(aggregationDepth)) val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) @@ -793,312 +803,3 @@ 
class LinearRegressionSummary private[regression] ( } -/** - * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, - * as used in linear regression for samples in sparse or dense vector in an online fashion. - * - * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of - * the corresponding joint dataset. - * - * For improving the convergence rate during the optimization process, and also preventing against - * features with very large variances exerting an overly large influence during model training, - * package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce - * the condition number, and then trains the model in scaled space but returns the coefficients in - * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf - * - * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache - * the standardized dataset since it will create a lot of overhead. As a result, we perform the - * scaling implicitly when we compute the objective function. The following is the mathematical - * derivation. - * - * Note that we don't deal with intercept by adding bias here, because the intercept - * can be computed using closed form after the coefficients are converged. - * See this discussion for detail. - * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet - * - * When training with intercept enabled, - * The objective function in the scaled space is given by - * - *
- * $$ - * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, - * $$ - *
- * - * where $\bar{x_i}$ is the mean of $x_i$, $\hat{x_i}$ is the standard deviation of $x_i$, - * $\bar{y}$ is the mean of label, and $\hat{y}$ is the standard deviation of label. - * - * If we fitting the intercept disabled (that is forced through 0.0), - * we can use the same equation except we set $\bar{y}$ and $\bar{x_i}$ to 0 instead - * of the respective means. - * - * This can be rewritten as - * - *
- * $$ - * \begin{align} - * L &= 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} - * + \bar{y} / \hat{y}||^2 \\ - * &= 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 - * \end{align} - * $$ - *
- * - * where $w_i^\prime$ is the effective coefficients defined by $w_i/\hat{x_i}$, offset is - * - *
- * $$ - * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. - * $$ - *
- * - * and diff is - * - *
- * $$ - * \sum_i w_i^\prime x_i - y / \hat{y} + offset - * $$ - *
- * - * Note that the effective coefficients and offset don't depend on training dataset, - * so they can be precomputed. - * - * Now, the first derivative of the objective function in scaled space is - * - *
- * $$ - * \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i} - * $$ - *
- * - * However, $(x_i - \bar{x_i})$ will densify the computation, so it's not - * an ideal formula when the training dataset is sparse format. - * - * This can be addressed by adding the dense $\bar{x_i} / \hat{x_i}$ terms - * in the end by keeping the sum of diff. The first derivative of total - * objective function from all the samples is - * - * - *
- * $$ - * \begin{align} - * \frac{\partial L}{\partial w_i} &= - * 1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i} \\ - * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i}) \\ - * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i) - * \end{align} - * $$ - *
- * - * where $correction_i = - diffSum \bar{x_i} / \hat{x_i}$ - * - * A simple math can show that diffSum is actually zero, so we don't even - * need to add the correction terms in the end. From the definition of diff, - * - *
- * $$ - * \begin{align} - * diffSum &= \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) - * / \hat{x_i} - (y_j - \bar{y}) / \hat{y}) \\ - * &= N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y}) \\ - * &= 0 - * \end{align} - * $$ - *
- * - * As a result, the first derivative of the total objective function only depends on - * the training dataset, which can be easily computed in distributed fashion, and is - * sparse format friendly. - * - *
- * $$ - * \frac{\partial L}{\partial w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - * $$ - *
- * - * @param bcCoefficients The broadcast coefficients corresponding to the features. - * @param labelStd The standard deviation value of the label. - * @param labelMean The mean value of the label. - * @param fitIntercept Whether to fit an intercept term. - * @param bcFeaturesStd The broadcast standard deviation values of the features. - * @param bcFeaturesMean The broadcast mean values of the features. - */ -private class LeastSquaresAggregator( - bcCoefficients: Broadcast[Vector], - labelStd: Double, - labelMean: Double, - fitIntercept: Boolean, - bcFeaturesStd: Broadcast[Array[Double]], - bcFeaturesMean: Broadcast[Array[Double]]) extends Serializable { - - private var totalCnt: Long = 0L - private var weightSum: Double = 0.0 - private var lossSum = 0.0 - - private val dim = bcCoefficients.value.size - // make transient so we do not serialize between aggregation stages - @transient private lazy val featuresStd = bcFeaturesStd.value - @transient private lazy val effectiveCoefAndOffset = { - val coefficientsArray = bcCoefficients.value.toArray.clone() - val featuresMean = bcFeaturesMean.value - var sum = 0.0 - var i = 0 - val len = coefficientsArray.length - while (i < len) { - if (featuresStd(i) != 0.0) { - coefficientsArray(i) /= featuresStd(i) - sum += coefficientsArray(i) * featuresMean(i) - } else { - coefficientsArray(i) = 0.0 - } - i += 1 - } - val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 - (Vectors.dense(coefficientsArray), offset) - } - // do not use tuple assignment above because it will circumvent the @transient tag - @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1 - @transient private lazy val offset = effectiveCoefAndOffset._2 - - private lazy val gradientSumArray = Array.ofDim[Double](dim) - - /** - * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient - * of the objective function. - * - * @param instance The instance of data point to be added. - * @return This LeastSquaresAggregator object. - */ - def add(instance: Instance): this.type = { - instance match { case Instance(label, weight, features) => - - if (weight == 0.0) return this - - val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset - - if (diff != 0) { - val localGradientSumArray = gradientSumArray - val localFeaturesStd = featuresStd - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += weight * diff * value / localFeaturesStd(index) - } - } - lossSum += weight * diff * diff / 2.0 - } - - totalCnt += 1 - weightSum += weight - this - } - } - - /** - * Merge another LeastSquaresAggregator, and update the loss and gradient - * of the objective function. - * (Note that it's in place merging; as a result, `this` object will be modified.) - * - * @param other The other LeastSquaresAggregator to be merged. - * @return This LeastSquaresAggregator object. 
- */ - def merge(other: LeastSquaresAggregator): this.type = { - - if (other.weightSum != 0) { - totalCnt += other.totalCnt - weightSum += other.weightSum - lossSum += other.lossSum - - var i = 0 - val localThisGradientSumArray = this.gradientSumArray - val localOtherGradientSumArray = other.gradientSumArray - while (i < dim) { - localThisGradientSumArray(i) += localOtherGradientSumArray(i) - i += 1 - } - } - this - } - - def count: Long = totalCnt - - def loss: Double = { - require(weightSum > 0.0, s"The effective number of instances should be " + - s"greater than 0.0, but $weightSum.") - lossSum / weightSum - } - - def gradient: Vector = { - require(weightSum > 0.0, s"The effective number of instances should be " + - s"greater than 0.0, but $weightSum.") - val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / weightSum, result) - result - } -} - -/** - * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. - * It returns the loss and gradient with L2 regularization at a particular point (coefficients). - * It's used in Breeze's convex optimization routines. - */ -private class LeastSquaresCostFun( - instances: RDD[Instance], - labelStd: Double, - labelMean: Double, - fitIntercept: Boolean, - standardization: Boolean, - bcFeaturesStd: Broadcast[Array[Double]], - bcFeaturesMean: Broadcast[Array[Double]], - effectiveL2regParam: Double, - aggregationDepth: Int) extends DiffFunction[BDV[Double]] { - - override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { - val coeffs = Vectors.fromBreeze(coefficients) - val bcCoeffs = instances.context.broadcast(coeffs) - val localFeaturesStd = bcFeaturesStd.value - - val leastSquaresAggregator = { - val seqOp = (c: LeastSquaresAggregator, instance: Instance) => c.add(instance) - val combOp = (c1: LeastSquaresAggregator, c2: LeastSquaresAggregator) => c1.merge(c2) - - instances.treeAggregate( - new LeastSquaresAggregator(bcCoeffs, labelStd, labelMean, fitIntercept, bcFeaturesStd, - bcFeaturesMean))(seqOp, combOp, aggregationDepth) - } - - val totalGradientArray = leastSquaresAggregator.gradient.toArray - bcCoeffs.destroy(blocking = false) - - val regVal = if (effectiveL2regParam == 0.0) { - 0.0 - } else { - var sum = 0.0 - coeffs.foreachActive { (index, value) => - // The following code will compute the loss of the regularization; also - // the gradient of the regularization, and add back to totalGradientArray. - sum += { - if (standardization) { - totalGradientArray(index) += effectiveL2regParam * value - value * value - } else { - if (localFeaturesStd(index) != 0.0) { - // If `standardization` is false, we still standardize the data - // to improve the rate of convergence; as a result, we have to - // perform this reverse standardization by penalizing each component - // differently to get effectively the same objective function when - // the training dataset is not standardized. 
- val temp = value / (localFeaturesStd(index) * localFeaturesStd(index)) - totalGradientArray(index) += effectiveL2regParam * temp - value * temp - } else { - 0.0 - } - } - } - } - 0.5 * effectiveL2regParam * sum - } - - (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray)) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index df2a9c0dd5094..3ad08c46d204d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -85,7 +85,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm) - data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum() + val cost = data + .map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum() + bcCentersWithNorm.destroy(blocking = false) + cost } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 663f63c25a940..4ab420058f33d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -320,6 +320,7 @@ class LocalLDAModel private[spark] ( docBound }.sum() + ElogbetaBc.destroy(blocking = false) // Bound component for prob(topic-term distributions): // E[log p(beta | eta) - log q(beta | lambda)] @@ -372,7 +373,6 @@ class LocalLDAModel private[spark] ( */ private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) - val expElogbetaBc = sc.broadcast(expElogbeta) val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k @@ -383,7 +383,7 @@ class LocalLDAModel private[spark] ( } else { val (gamma, _, _) = OnlineLDAOptimizer.variationalTopicInference( termCounts, - expElogbetaBc.value, + expElogbeta, docConcentrationBrz, gammaShape, k) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 07a67a9e719db..593cdd602fafc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -246,6 +246,7 @@ object GradientDescent extends Logging { // c: (grad, loss, count) (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3) }) + bcWeights.destroy(blocking = false) if (miniBatchSize > 0) { /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a6a1c2b4f32bd..6183606a7b2ac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row +import org.apache.spark.util.Utils class Word2VecSuite extends SparkFunSuite 
with MLlibTestSparkContext with DefaultReadWriteTest { @@ -188,6 +189,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) } + test("Word2Vec read/write numPartitions calculation") { + val smallModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions( + Utils.byteStringAsBytes("64m"), numWords = 10, vectorSize = 5) + assert(smallModelNumPartitions === 1) + val largeModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions( + Utils.byteStringAsBytes("64m"), numWords = 1000000, vectorSize = 5000) + assert(largeModelNumPartitions > 1) + } + test("Word2Vec read/write") { val t = new Word2Vec() .setInputCol("myInputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala new file mode 100644 index 0000000000000..7a4faeb1c10bf --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ + +class DifferentiableLossAggregatorSuite extends SparkFunSuite { + + import DifferentiableLossAggregatorSuite.TestAggregator + + private val instances1 = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + private val instances2 = Seq( + Instance(0.2, 0.4, Vectors.dense(0.8, 2.5)), + Instance(0.8, 0.9, Vectors.dense(2.0, 1.3)), + Instance(1.5, 0.2, Vectors.dense(3.0, 0.2)) + ) + + private def assertEqual[T, Agg <: DifferentiableLossAggregator[T, Agg]]( + agg1: DifferentiableLossAggregator[T, Agg], + agg2: DifferentiableLossAggregator[T, Agg]): Unit = { + assert(agg1.weight === agg2.weight) + assert(agg1.loss === agg2.loss) + assert(agg1.gradient === agg2.gradient) + } + + test("empty aggregator") { + val numFeatures = 5 + val coef = Vectors.dense(Array.fill(numFeatures)(1.0)) + val agg = new TestAggregator(numFeatures)(coef) + withClue("cannot get loss for empty aggregator") { + intercept[IllegalArgumentException] { + agg.loss + } + } + withClue("cannot get gradient for empty aggregator") { + intercept[IllegalArgumentException] { + agg.gradient + } + } + } + + test("aggregator initialization") { + val numFeatures = 3 + val coef = Vectors.dense(Array.fill(numFeatures)(1.0)) + val agg = new TestAggregator(numFeatures)(coef) + agg.add(Instance(1.0, 0.3, Vectors.dense(Array.fill(numFeatures)(1.0)))) + assert(agg.gradient.size === 3) + assert(agg.weight === 0.3) + } + + test("merge aggregators") { + val coefficients = Vectors.dense(0.5, -0.1) + val agg1 = new TestAggregator(2)(coefficients) + val agg2 = new TestAggregator(2)(coefficients) + val aggBadDim = new TestAggregator(1)(Vectors.dense(0.5)) + aggBadDim.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + instances1.foreach(agg1.add) + + // merge incompatible aggregators + withClue("cannot merge aggregators with different dimensions") { + intercept[IllegalArgumentException] { + agg1.merge(aggBadDim) + } + } + + // merge empty other + val mergedEmptyOther = agg1.merge(agg2) + assertEqual(mergedEmptyOther, agg1) + assert(mergedEmptyOther === agg1) + + // merge empty this + val agg3 = new TestAggregator(2)(coefficients) + val mergedEmptyThis = agg3.merge(agg1) + assertEqual(mergedEmptyThis, agg1) + assert(mergedEmptyThis !== agg1) + + instances2.foreach(agg2.add) + val (loss1, weight1, grad1) = (agg1.loss, agg1.weight, agg1.gradient) + val (loss2, weight2, grad2) = (agg2.loss, agg2.weight, agg2.gradient) + val merged = agg1.merge(agg2) + + // check pointers are equal + assert(merged === agg1) + + // loss should be weighted average of the two individual losses + assert(merged.loss === (loss1 * weight1 + loss2 * weight2) / (weight1 + weight2)) + assert(merged.weight === weight1 + weight2) + + // gradient should be weighted average of individual gradients + val addedGradients = Vectors.dense(grad1.toArray.clone()) + BLAS.scal(weight1, addedGradients) + BLAS.axpy(weight2, grad2, addedGradients) + BLAS.scal(1 / (weight1 + weight2), addedGradients) + assert(merged.gradient === addedGradients) + } + + test("loss, gradient, weight") { + val coefficients = Vectors.dense(0.5, -0.1) + val agg = new TestAggregator(2)(coefficients) + instances1.foreach(agg.add) + val errors = instances1.map { case Instance(label, _, 
features) => + label - BLAS.dot(features, coefficients) + } + val expectedLoss = errors.zip(instances1).map { case (error: Double, instance: Instance) => + instance.weight * error * error / 2.0 + } + val expectedGradient = Vectors.dense(0.0, 0.0) + errors.zip(instances1).foreach { case (error, instance) => + BLAS.axpy(instance.weight * error, instance.features, expectedGradient) + } + BLAS.scal(1.0 / agg.weight, expectedGradient) + val weightSum = instances1.map(_.weight).sum + + assert(agg.weight ~== weightSum relTol 1e-5) + assert(agg.loss ~== expectedLoss.sum / weightSum relTol 1e-5) + assert(agg.gradient ~== expectedGradient relTol 1e-5) + } +} + +object DifferentiableLossAggregatorSuite { + /** + * Dummy aggregator that represents least squares cost with no intercept. + */ + class TestAggregator(numFeatures: Int)(coefficients: Vector) + extends DifferentiableLossAggregator[Instance, TestAggregator] { + + protected override val dim: Int = numFeatures + + override def add(instance: Instance): TestAggregator = { + val error = instance.label - BLAS.dot(coefficients, instance.features) + weightSum += instance.weight + lossSum += instance.weight * error * error / 2.0 + (0 until dim).foreach { j => + gradientSumArray(j) += instance.weight * error * instance.features(j) + } + this + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala new file mode 100644 index 0000000000000..d1cb0d380e7a5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var instances: Array[Instance] = _ + @transient var instancesConstantFeature: Array[Instance] = _ + @transient var instancesConstantLabel: Array[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + instances = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + instancesConstantFeature = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)), + Instance(2.0, 0.3, Vectors.dense(1.0, 0.5)) + ) + instancesConstantLabel = Array( + Instance(1.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(1.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + } + + /** Get feature and label summarizers for provided data. */ + def getSummarizers( + instances: Array[Instance]): (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), + c._2.add(Vectors.dense(instance.label), instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.aggregate( + new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer + )(seqOp, combOp) + } + + /** Get summary statistics for some data and create a new LeastSquaresAggregator. 
*/ + def getNewAggregator( + instances: Array[Instance], + coefficients: Vector, + fitIntercept: Boolean): LeastSquaresAggregator = { + val (featuresSummarizer, ySummarizer) = getSummarizers(instances) + val yStd = math.sqrt(ySummarizer.variance(0)) + val yMean = ySummarizer.mean(0) + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd) + val featuresMean = featuresSummarizer.mean + val bcFeaturesMean = spark.sparkContext.broadcast(featuresMean.toArray) + val bcCoefficients = spark.sparkContext.broadcast(coefficients) + new LeastSquaresAggregator(yStd, yMean, fitIntercept, bcFeaturesStd, + bcFeaturesMean)(bcCoefficients) + } + + test("check sizes") { + val coefficients = Vectors.dense(1.0, 2.0) + val aggIntercept = getNewAggregator(instances, coefficients, fitIntercept = true) + val aggNoIntercept = getNewAggregator(instances, coefficients, fitIntercept = false) + instances.foreach(aggIntercept.add) + instances.foreach(aggNoIntercept.add) + + // least squares agg does not include intercept in its gradient array + assert(aggIntercept.gradient.size === 2) + assert(aggNoIntercept.gradient.size === 2) + } + + test("check correctness") { + /* + Check that the aggregator computes loss/gradient for: + 0.5 * sum_i=1^N ([sum_j=1^D beta_j * ((x_j - x_j,bar) / sigma_j)] - ((y - ybar) / sigma_y))^2 + */ + val coefficients = Vectors.dense(1.0, 2.0) + val numFeatures = coefficients.size + val (featuresSummarizer, ySummarizer) = getSummarizers(instances) + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + val featuresMean = featuresSummarizer.mean.toArray + val yStd = math.sqrt(ySummarizer.variance(0)) + val yMean = ySummarizer.mean(0) + + val agg = getNewAggregator(instances, coefficients, fitIntercept = true) + instances.foreach(agg.add) + + // compute (y - pred) analytically + val errors = instances.map { case Instance(l, w, f) => + val scaledFeatures = (0 until numFeatures).map { j => + (f.toArray(j) - featuresMean(j)) / featuresStd(j) + }.toArray + val scaledLabel = (l - yMean) / yStd + BLAS.dot(coefficients, Vectors.dense(scaledFeatures)) - scaledLabel + } + + // compute expected loss sum analytically + val expectedLoss = errors.zip(instances).map { case (error, instance) => + instance.weight * error * error / 2.0 + } + + // compute gradient analytically from instances + val expectedGradient = Vectors.dense(0.0, 0.0) + errors.zip(instances).foreach { case (error, instance) => + val scaledFeatures = (0 until numFeatures).map { j => + instance.weight * instance.features.toArray(j) / featuresStd(j) + }.toArray + BLAS.axpy(error, Vectors.dense(scaledFeatures), expectedGradient) + } + + val weightSum = instances.map(_.weight).sum + BLAS.scal(1.0 / weightSum, expectedGradient) + assert(agg.loss ~== (expectedLoss.sum / weightSum) relTol 1e-5) + assert(agg.gradient ~== expectedGradient relTol 1e-5) + } + + test("check with zero standard deviation") { + val coefficients = Vectors.dense(1.0, 2.0) + val aggConstantFeature = getNewAggregator(instancesConstantFeature, coefficients, + fitIntercept = true) + instances.foreach(aggConstantFeature.add) + // constant features should not affect gradient + assert(aggConstantFeature.gradient(0) === 0.0) + + withClue("LeastSquaresAggregator does not support zero standard deviation of the label") { + intercept[IllegalArgumentException] { + getNewAggregator(instancesConstantLabel, coefficients, fitIntercept = true) + } + } + } +} diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala new file mode 100644 index 0000000000000..0794417a8d4bb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.loss + +import org.apache.spark.SparkFunSuite + +class DifferentiableRegularizationSuite extends SparkFunSuite { + + test("L2 regularization") { + val shouldApply = (_: Int) => true + val regParam = 0.3 + val coefficients = Array(1.0, 3.0, -2.0) + val numFeatures = coefficients.size + + // check without features standard + val regFun = new L2Regularization(regParam, shouldApply, None) + val (loss, grad) = regFun.calculate(coefficients) + assert(loss === 0.5 * regParam * coefficients.map(x => x * x).sum) + assert(grad === coefficients.map(_ * regParam)) + + // check with features standard + val featuresStd = Array(0.1, 1.1, 0.5) + val regFunStd = new L2Regularization(regParam, shouldApply, Some(featuresStd)) + val (lossStd, gradStd) = regFunStd.calculate(coefficients) + val expectedLossStd = 0.5 * regParam * (0 until numFeatures).map { j => + coefficients(j) * coefficients(j) / (featuresStd(j) * featuresStd(j)) + }.sum + val expectedGradientStd = (0 until numFeatures).map { j => + regParam * coefficients(j) / (featuresStd(j) * featuresStd(j)) + }.toArray + assert(lossStd === expectedLossStd) + assert(gradStd === expectedGradientStd) + + // check should apply + val shouldApply2 = (i: Int) => i == 1 + val regFunApply = new L2Regularization(regParam, shouldApply2, None) + val (lossApply, gradApply) = regFunApply.calculate(coefficients) + assert(lossApply === 0.5 * regParam * coefficients(1) * coefficients(1)) + assert(gradApply === Array(0.0, coefficients(1) * regParam, 0.0)) + + // check with zero features standard + val featuresStdZero = Array(0.1, 0.0, 0.5) + val regFunStdZero = new L2Regularization(regParam, shouldApply, Some(featuresStdZero)) + val (_, gradStdZero) = regFunStdZero.calculate(coefficients) + assert(gradStdZero(1) == 0.0) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala new file mode 100644 index 0000000000000..cd5cebee5f7b8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.loss + +import org.apache.spark.SparkFunSuite +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregatorSuite.TestAggregator +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD + +class RDDLossFunctionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var instances: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + instances = sc.parallelize(Seq( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + )) + } + + test("regularization") { + val coefficients = Vectors.dense(0.5, -0.1) + val regLossFun = new L2Regularization(0.1, (_: Int) => true, None) + val getAgg = (bvec: Broadcast[Vector]) => new TestAggregator(2)(bvec.value) + val lossNoReg = new RDDLossFunction(instances, getAgg, None) + val lossWithReg = new RDDLossFunction(instances, getAgg, Some(regLossFun)) + + val (loss1, grad1) = lossNoReg.calculate(coefficients.asBreeze.toDenseVector) + val (regLoss, regGrad) = regLossFun.calculate(coefficients.toArray) + val (loss2, grad2) = lossWithReg.calculate(coefficients.asBreeze.toDenseVector) + + BLAS.axpy(1.0, Vectors.fromBreeze(grad1), Vectors.dense(regGrad)) + assert(Vectors.dense(regGrad) ~== Vectors.fromBreeze(grad2) relTol 1e-5) + assert(loss1 + regLoss === loss2) + } + + test("empty RDD") { + val rdd = sc.parallelize(Seq.empty[Instance]) + val coefficients = Vectors.dense(0.5, -0.1) + val getAgg = (bv: Broadcast[Vector]) => new TestAggregator(2)(bv.value) + val lossFun = new RDDLossFunction(rdd, getAgg, None) + withClue("cannot calculate cost for empty dataset") { + intercept[IllegalArgumentException]{ + lossFun.calculate(coefficients.asBreeze.toDenseVector) + } + } + } + + test("versus aggregating on an iterable") { + val coefficients = Vectors.dense(0.5, -0.1) + val getAgg = (bv: Broadcast[Vector]) => new TestAggregator(2)(bv.value) + val lossFun = new RDDLossFunction(instances, getAgg, None) + val (loss, grad) = lossFun.calculate(coefficients.asBreeze.toDenseVector) + + // just map the aggregator over the instances array + val agg = new TestAggregator(2)(coefficients) + instances.collect().foreach(agg.add) + + assert(loss === agg.loss) + assert(Vectors.fromBreeze(grad) === agg.gradient) + } + +} diff --git a/pom.xml b/pom.xml index 0533a8dcf2e0a..5f524079495c0 100644 --- a/pom.xml +++ b/pom.xml @@ -83,6 +83,7 @@ common/sketch + common/kvstore common/network-common common/network-shuffle common/unsafe @@ -441,6 +442,11 @@ httpcore ${commons.httpcore.version} + + 
org.fusesource.leveldbjni + leveldbjni-all + 1.8 + org.seleniumhq.selenium selenium-java @@ -588,6 +594,11 @@ metrics-graphite ${codahale.metrics.version} + + com.fasterxml.jackson.core + jackson-core + ${fasterxml.jackson.version} + com.fasterxml.jackson.core jackson-databind @@ -2521,6 +2532,7 @@ hadoop-2.7 2.7.3 + 2.7.1 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b5362ec1ae452..89b0c7a3ab7b0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -50,10 +50,10 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( - core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, _* + core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* ) = Seq( "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", - "tags", "sketch" + "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(mesos, yarn, sparkGangliaLgpl, @@ -310,7 +310,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010 + unsafe, tags, sqlKafka010, kvstore ).contains(x) } diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 99abfcc556dff..8541403dfe2f1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1175,18 +1175,23 @@ def agg(self, *exprs): @since(2.0) def union(self, other): - """ Return a new :class:`DataFrame` containing union of rows in this - frame and another frame. + """ Return a new :class:`DataFrame` containing union of rows in this and another frame. This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by a distinct. + + Also as standard in SQL, this function resolves columns by position (not by name). """ return DataFrame(self._jdf.union(other._jdf), self.sql_ctx) @since(1.3) def unionAll(self, other): - """ Return a new :class:`DataFrame` containing union of rows in this - frame and another frame. + """ Return a new :class:`DataFrame` containing union of rows in this and another frame. + + This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union + (that does deduplication of elements), use this function followed by a distinct. + + Also as standard in SQL, this function resolves columns by position (not by name). .. note:: Deprecated in 2.0, use union instead. """ diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 19e253394f1b2..56d697f359614 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -56,4 +56,11 @@ package object config { .stringConf .createOptional + private [spark] val DRIVER_LABELS = + ConfigBuilder("spark.mesos.driver.labels") + .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value" + + "pairs should be separated by a colon, and commas used to list more than one." + + "Ex. 
key:value,key2:value2") + .stringConf + .createOptional } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1bc6f71860c3f..577f9a876b381 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -30,11 +30,13 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} +import org.apache.spark.deploy.mesos.config import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils + /** * Tracks the current state of a Mesos Task that runs a Spark driver. * @param driverDescription Submitted driver description from @@ -525,15 +527,17 @@ private[spark] class MesosClusterScheduler( offer.remainingResources = finalResources.asJava val appName = desc.conf.get("spark.app.name") - val taskInfo = TaskInfo.newBuilder() + + TaskInfo.newBuilder() .setTaskId(taskId) .setName(s"Driver for ${appName}") .setSlaveId(offer.offer.getSlaveId) .setCommand(buildDriverCommand(desc)) .addAllResources(cpuResourcesToUse.asJava) .addAllResources(memResourcesToUse.asJava) - taskInfo.setContainer(MesosSchedulerBackendUtil.containerInfo(desc.conf)) - taskInfo.build + .setLabels(MesosProtoUtils.mesosLabels(desc.conf.get(config.DRIVER_LABELS).getOrElse(""))) + .setContainer(MesosSchedulerBackendUtil.containerInfo(desc.conf)) + .build } /** diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index ac7aec7b0a034..871685c6cccc0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -419,16 +419,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) .setName(s"${sc.appName} $taskId") - - taskBuilder.addAllResources(resourcesToUse.asJava) - taskBuilder.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) - - val labelsBuilder = taskBuilder.getLabelsBuilder - val labels = buildMesosLabels().asJava - - labelsBuilder.addAllLabels(labels) - - taskBuilder.setLabels(labelsBuilder) + .setLabels(MesosProtoUtils.mesosLabels(taskLabels)) + .addAllResources(resourcesToUse.asJava) + .setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava @@ -444,21 +437,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( tasks.toMap } - private def buildMesosLabels(): List[Label] = { - taskLabels.split(",").flatMap(label => - label.split(":") match { - case Array(key, value) => - Some(Label.newBuilder() - .setKey(key) - .setValue(value) - .build()) - case _ => - logWarning(s"Unable to parse $label into a key:value label for the task.") - None - } - ).toList - } - /** Extracts task needed resources from a list of available resources. */ private def partitionTaskResources( resources: JList[Resource],
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala new file mode 100644 index 0000000000000..fea01c7068c9a --- /dev/null +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import scala.collection.JavaConverters._ + +import org.apache.mesos.Protos + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging + +object MesosProtoUtils extends Logging { + + /** Parses a label string of the format specified in spark.mesos.task.labels. */ + def mesosLabels(labelsStr: String): Protos.Labels.Builder = { + val labels: Seq[Protos.Label] = if (labelsStr == "") { + Seq() + } else { + labelsStr.split("""(?<!\\),""").toSeq.map { labelStr => + val parts = labelStr.split("""(?<!\\):""") + if (parts.length != 2) { + throw new SparkException(s"Malformed label: $labelStr") + } + + val cleanedParts = parts + .map(part => 
part.replaceAll("""\\,""", ",")) + .map(part => part.replaceAll("""\\:""", ":")) + + Protos.Label.newBuilder() + .setKey(cleanedParts(0)) + .setValue(cleanedParts(1)) + .build() + } + } + + Protos.Labels.newBuilder().addAllLabels(labels.asJava) + } +} diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 32967b04cd346..0bb47906347d5 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -248,6 +248,33 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(networkInfos.get(0).getName == "test-network-name") } + test("supports spark.mesos.driver.labels") { + setScheduler() + + val mem = 1000 + val cpu = 1 + + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", mem, cpu, true, + command, + Map("spark.mesos.executor.home" -> "test", + "spark.app.name" -> "test", + "spark.mesos.driver.labels" -> "key:value"), + "s1", + new Date())) + + assert(response.success) + + val offer = Utils.createOffer("o1", "s1", mem, cpu) + scheduler.resourceOffers(driver, List(offer).asJava) + + val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") + val labels = launchedTasks.head.getLabels + assert(labels.getLabelsCount == 1) + assert(labels.getLabels(0).getKey == "key") + assert(labels.getLabels(0).getValue == "value") + } + test("can kill supervised drivers") { val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 0418bfbaa5ed8..7cca5fedb31eb 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -532,29 +532,6 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getLabels.equals(taskLabels)) } - test("mesos ignored invalid labels and sets configurable labels on tasks") { - val taskLabelsString = "mesos:test,label:test,incorrect:label:here" - setBackend(Map( - "spark.mesos.task.labels" -> taskLabelsString - )) - - // Build up the labels - val taskLabels = Protos.Labels.newBuilder() - .addLabels(Protos.Label.newBuilder() - .setKey("mesos").setValue("test").build()) - .addLabels(Protos.Label.newBuilder() - .setKey("label").setValue("test").build()) - .build() - - val offers = List(Resources(backend.executorMemory(sc), 1)) - offerResources(offers) - val launchedTasks = verifyTaskLaunched(driver, "o1") - - val labels = launchedTasks.head.getLabels - - assert(launchedTasks.head.getLabels.equals(taskLabels)) - } - test("mesos supports spark.mesos.network.name") { setBackend(Map( "spark.mesos.network.name" -> "test-network-name" diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala 
b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala new file mode 100644 index 0000000000000..36a4c1ab1ad25 --- /dev/null +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.spark.SparkFunSuite + +class MesosProtoUtilsSuite extends SparkFunSuite { + test("mesosLabels") { + val labels = MesosProtoUtils.mesosLabels("key:value") + assert(labels.getLabelsCount == 1) + val label = labels.getLabels(0) + assert(label.getKey == "key") + assert(label.getValue == "value") + + val labels2 = MesosProtoUtils.mesosLabels("key:value\\:value") + assert(labels2.getLabelsCount == 1) + val label2 = labels2.getLabels(0) + assert(label2.getKey == "key") + assert(label2.getValue == "value:value") + + val labels3 = MesosProtoUtils.mesosLabels("key:value,key2:value2") + assert(labels3.getLabelsCount == 2) + assert(labels3.getLabels(0).getKey == "key") + assert(labels3.getLabels(0).getValue == "value") + assert(labels3.getLabels(1).getKey == "key2") + assert(labels3.getLabels(1).getValue == "value2") + + val labels4 = MesosProtoUtils.mesosLabels("key:value\\,value") + assert(labels4.getLabelsCount == 1) + assert(labels4.getLabels(0).getKey == "key") + assert(labels4.getLabels(0).getValue == "value,value") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 86a73a319ec3f..7683ee7074e7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -267,16 +267,11 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - val array = - Invoke( - MapObjects( - p => deserializerFor(et, Some(p)), - getPath, - inferDataType(et)._1), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + MapObjects( + p => deserializerFor(et, Some(p)), + getPath, + inferDataType(et)._1, + customCollectionCls = Some(c)) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 87130532c89bc..d580cf4d3391c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -335,31 +335,12 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - val keyData = - Invoke( - MapObjects( - p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), - returnNullable = false), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) - - val valueData = - Invoke( - MapObjects( - p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), - returnNullable = false), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[scala.collection.immutable.Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) + CollectObjectsToMap( + p => deserializerFor(keyType, Some(p), walkedTypePath), + p => deserializerFor(valueType, Some(p), walkedTypePath), + getPath, + mirror.runtimeClass(t.typeSymbol.asClass) + ) case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4bf360f42034b..877328164a8a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,51 +17,68 @@ package org.apache.spark.sql.catalyst.analysis -import java.lang.reflect.Modifier +import java.util.Locale +import javax.annotation.concurrent.GuardedBy +import scala.collection.mutable import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.xml._ -import org.apache.spark.sql.catalyst.util.StringKeyHashMap import org.apache.spark.sql.types._ /** * A catalog for looking up user defined functions, used by an [[Analyzer]]. * - * Note: The implementation should be thread-safe to allow concurrent access. + * Note: + * 1) The implementation should be thread-safe to allow concurrent access. + * 2) the database name is always case-sensitive here, callers are responsible to + * format the database name w.r.t. case-sensitive config. 
*/ trait FunctionRegistry { - final def registerFunction(name: String, builder: FunctionBuilder): Unit = { - registerFunction(name, new ExpressionInfo(builder.getClass.getCanonicalName, name), builder) + final def registerFunction(name: FunctionIdentifier, builder: FunctionBuilder): Unit = { + val info = new ExpressionInfo( + builder.getClass.getCanonicalName, name.database.orNull, name.funcName) + registerFunction(name, info, builder) } - def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit + def registerFunction( + name: FunctionIdentifier, + info: ExpressionInfo, + builder: FunctionBuilder): Unit + + /* Create or replace a temporary function. */ + final def createOrReplaceTempFunction(name: String, builder: FunctionBuilder): Unit = { + registerFunction( + FunctionIdentifier(name), + builder) + } @throws[AnalysisException]("If function does not exist") - def lookupFunction(name: String, children: Seq[Expression]): Expression + def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression /* List all of the registered function names. */ - def listFunction(): Seq[String] + def listFunction(): Seq[FunctionIdentifier] /* Get the class of the registered function by specified name. */ - def lookupFunction(name: String): Option[ExpressionInfo] + def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] /* Get the builder of the registered function by specified name. */ - def lookupFunctionBuilder(name: String): Option[FunctionBuilder] + def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] /** Drop a function and return whether the function existed. */ - def dropFunction(name: String): Boolean + def dropFunction(name: FunctionIdentifier): Boolean /** Checks if a function with a given name exists. */ - def functionExists(name: String): Boolean = lookupFunction(name).isDefined + def functionExists(name: FunctionIdentifier): Boolean = lookupFunction(name).isDefined /** Clear all registered functions. 
*/ def clear(): Unit @@ -72,39 +89,47 @@ trait FunctionRegistry { class SimpleFunctionRegistry extends FunctionRegistry { - protected val functionBuilders = - StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) + @GuardedBy("this") + private val functionBuilders = + new mutable.HashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)] + + // Resolution of the function name is always case insensitive, but the database name + // depends on the caller + private def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = { + FunctionIdentifier(name.funcName.toLowerCase(Locale.ROOT), name.database) + } override def registerFunction( - name: String, + name: FunctionIdentifier, info: ExpressionInfo, builder: FunctionBuilder): Unit = synchronized { - functionBuilders.put(name, (info, builder)) + functionBuilders.put(normalizeFuncName(name), (info, builder)) } - override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { val func = synchronized { - functionBuilders.get(name).map(_._2).getOrElse { + functionBuilders.get(normalizeFuncName(name)).map(_._2).getOrElse { throw new AnalysisException(s"undefined function $name") } } func(children) } - override def listFunction(): Seq[String] = synchronized { - functionBuilders.iterator.map(_._1).toList.sorted + override def listFunction(): Seq[FunctionIdentifier] = synchronized { + functionBuilders.iterator.map(_._1).toList } - override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized { - functionBuilders.get(name).map(_._1) + override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = synchronized { + functionBuilders.get(normalizeFuncName(name)).map(_._1) } - override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized { - functionBuilders.get(name).map(_._2) + override def lookupFunctionBuilder( + name: FunctionIdentifier): Option[FunctionBuilder] = synchronized { + functionBuilders.get(normalizeFuncName(name)).map(_._2) } - override def dropFunction(name: String): Boolean = synchronized { - functionBuilders.remove(name).isDefined + override def dropFunction(name: FunctionIdentifier): Boolean = synchronized { + functionBuilders.remove(normalizeFuncName(name)).isDefined } override def clear(): Unit = synchronized { @@ -125,28 +150,28 @@ class SimpleFunctionRegistry extends FunctionRegistry { * functions are already filled in and the analyzer needs only to resolve attribute references. 
*/ object EmptyFunctionRegistry extends FunctionRegistry { - override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) - : Unit = { + override def registerFunction( + name: FunctionIdentifier, info: ExpressionInfo, builder: FunctionBuilder): Unit = { throw new UnsupportedOperationException } - override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - override def listFunction(): Seq[String] = { + override def listFunction(): Seq[FunctionIdentifier] = { throw new UnsupportedOperationException } - override def lookupFunction(name: String): Option[ExpressionInfo] = { + override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = { throw new UnsupportedOperationException } - override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = { + override def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] = { throw new UnsupportedOperationException } - override def dropFunction(name: String): Boolean = { + override def dropFunction(name: FunctionIdentifier): Boolean = { throw new UnsupportedOperationException } @@ -457,11 +482,13 @@ object FunctionRegistry { val builtin: SimpleFunctionRegistry = { val fr = new SimpleFunctionRegistry - expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) } + expressions.foreach { + case (name, (info, builder)) => fr.registerFunction(FunctionIdentifier(name), info, builder) + } fr } - val functionSet: Set[String] = builtin.listFunction().toSet + val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet /** See usage above. */ private def expression[T <: Expression](name: String) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 62a3482d9fac1..f068bce3e9b69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -58,9 +58,9 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => - ResolvedHint(plan, HintInfo(isBroadcastable = Option(true))) + ResolvedHint(plan, HintInfo(broadcast = true)) case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => - ResolvedHint(plan, HintInfo(isBroadcastable = Option(true))) + ResolvedHint(plan, HintInfo(broadcast = true)) case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => // Don't traverse down these nodes. @@ -89,7 +89,7 @@ object ResolveHints { case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. - ResolvedHint(h.child, HintInfo(isBroadcastable = Option(true))) + ResolvedHint(h.child, HintInfo(broadcast = true)) } else { // Otherwise, find within the subtree query plans that should be broadcasted. 
applyBroadcastHint(h.child, h.parameters.map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 974ef900e2eed..12ba5aedde026 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -160,6 +160,8 @@ abstract class ExternalCatalog */ def alterTableSchema(db: String, table: String, schema: StructType): Unit + def alterTableStats(db: String, table: String, stats: CatalogStatistics): Unit + def getTable(db: String, table: String): CatalogTable def getTableOption(db: String, table: String): Option[CatalogTable] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8a5319bebe54e..9820522a230e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -312,6 +312,15 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = schema) } + override def alterTableStats( + db: String, + table: String, + stats: CatalogStatistics): Unit = synchronized { + requireTableExists(db, table) + val origTable = catalog(db).tables(table).table + catalog(db).tables(table).table = origTable.copy(stats = Some(stats)) + } + override def getTable(db: String, table: String): CatalogTable = synchronized { requireTableExists(db, table) catalog(db).tables(table).table diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a78440df4f3e1..cf02da8993658 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI import java.util.Locale +import java.util.concurrent.Callable import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -125,14 +126,36 @@ class SessionCatalog( if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } - /** - * A cache of qualified table names to table relation plans. - */ - val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { + private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { val cacheSize = conf.tableRelationCacheSize CacheBuilder.newBuilder().maximumSize(cacheSize).build[QualifiedTableName, LogicalPlan]() } + /** This method provides a way to get a cached plan. */ + def getCachedPlan(t: QualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = { + tableRelationCache.get(t, c) + } + + /** This method provides a way to get a cached plan if the key exists. */ + def getCachedTable(key: QualifiedTableName): LogicalPlan = { + tableRelationCache.getIfPresent(key) + } + + /** This method provides a way to cache a plan. */ + def cacheTable(t: QualifiedTableName, l: LogicalPlan): Unit = { + tableRelationCache.put(t, l) + } + + /** This method provides a way to invalidate a cached plan. 
*/ + def invalidateCachedTable(key: QualifiedTableName): Unit = { + tableRelationCache.invalidate(key) + } + + /** This method provides a way to invalidate all the cached plans. */ + def invalidateAllCachedTables(): Unit = { + tableRelationCache.invalidateAll() + } + /** * This method is used to make the given path qualified before we * store this path in the underlying external catalog. So, when a path @@ -353,6 +376,19 @@ class SessionCatalog( schema.fields.map(_.name).exists(conf.resolver(_, colName)) } + /** + * Alter Spark's statistics of an existing metastore table identified by the provided table + * identifier. + */ + def alterTableStats(identifier: TableIdentifier, newStats: CatalogStatistics): Unit = { + val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + externalCatalog.alterTableStats(db, table, newStats) + } + /** * Return whether a table/view with the specified name exists. If no database is specified, check * with current database. @@ -1006,13 +1042,12 @@ class SessionCatalog( requireDbExists(db) val identifier = name.copy(database = Some(db)) if (functionExists(identifier)) { - // TODO: registry should just take in FunctionIdentifier for type safety - if (functionRegistry.functionExists(identifier.unquotedString)) { + if (functionRegistry.functionExists(identifier)) { // If we have loaded this function into the FunctionRegistry, // also drop it from there. // For a permanent function, because we loaded it to the FunctionRegistry // when it's first used, we also need to drop it from the FunctionRegistry. - functionRegistry.dropFunction(identifier.unquotedString) + functionRegistry.dropFunction(identifier) } externalCatalog.dropFunction(db, name.funcName) } else if (!ignoreIfNotExists) { @@ -1038,7 +1073,7 @@ class SessionCatalog( def functionExists(name: FunctionIdentifier): Boolean = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) requireDbExists(db) - functionRegistry.functionExists(name.unquotedString) || + functionRegistry.functionExists(name) || externalCatalog.functionExists(db, name.funcName) } @@ -1072,20 +1107,20 @@ class SessionCatalog( ignoreIfExists: Boolean, functionBuilder: Option[FunctionBuilder] = None): Unit = { val func = funcDefinition.identifier - if (functionRegistry.functionExists(func.unquotedString) && !ignoreIfExists) { + if (functionRegistry.functionExists(func) && !ignoreIfExists) { throw new AnalysisException(s"Function $func already exists") } val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) val builder = functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className)) - functionRegistry.registerFunction(func.unquotedString, info, builder) + functionRegistry.registerFunction(func, info, builder) } /** * Drop a temporary function. 
*/ def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) { + if (!functionRegistry.dropFunction(FunctionIdentifier(name)) && !ignoreIfNotExists) { throw new NoSuchTempFunctionException(name) } } @@ -1100,8 +1135,8 @@ class SessionCatalog( // A temporary function is a function that has been registered in functionRegistry // without a database name, and is neither a built-in function nor a Hive function name.database.isEmpty && - functionRegistry.functionExists(name.funcName) && - !FunctionRegistry.builtin.functionExists(name.funcName) && + functionRegistry.functionExists(name) && + !FunctionRegistry.builtin.functionExists(name) && !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } @@ -1117,8 +1152,8 @@ class SessionCatalog( // TODO: just make function registry take in FunctionIdentifier instead of duplicating this val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName) val qualifiedName = name.copy(database = database) - functionRegistry.lookupFunction(name.funcName) - .orElse(functionRegistry.lookupFunction(qualifiedName.unquotedString)) + functionRegistry.lookupFunction(name) + .orElse(functionRegistry.lookupFunction(qualifiedName)) .getOrElse { val db = qualifiedName.database.get requireDbExists(db) @@ -1153,19 +1188,19 @@ class SessionCatalog( // Note: the implementation of this function is a little bit convoluted. // We probably shouldn't use a single FunctionRegistry to register all three kinds of functions // (built-in, temp, and external). - if (name.database.isEmpty && functionRegistry.functionExists(name.funcName)) { + if (name.database.isEmpty && functionRegistry.functionExists(name)) { // This function has been already loaded into the function registry. - return functionRegistry.lookupFunction(name.funcName, children) + return functionRegistry.lookupFunction(name, children) } // If the name itself is not qualified, add the current database to it. val database = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val qualifiedName = name.copy(database = Some(database)) - if (functionRegistry.functionExists(qualifiedName.unquotedString)) { + if (functionRegistry.functionExists(qualifiedName)) { // This function has been already loaded into the function registry. // Unlike the above block, we find this function by using the qualified name. - return functionRegistry.lookupFunction(qualifiedName.unquotedString, children) + return functionRegistry.lookupFunction(qualifiedName, children) } // The function has not been loaded to the function registry, which means @@ -1186,7 +1221,7 @@ class SessionCatalog( // At here, we preserve the input from the user. registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false) // Now, we need to create the Expression. - functionRegistry.lookupFunction(qualifiedName.unquotedString, children) + functionRegistry.lookupFunction(qualifiedName, children) } /** @@ -1206,8 +1241,8 @@ class SessionCatalog( requireDbExists(dbName) val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f => FunctionIdentifier(f, Some(dbName)) } - val loadedFunctions = - StringUtils.filterPattern(functionRegistry.listFunction(), pattern).map { f => + val loadedFunctions = StringUtils + .filterPattern(functionRegistry.listFunction().map(_.unquotedString), pattern).map { f => // In functionRegistry, function names are stored as an unquoted format. 
Try(parser.parseFunctionIdentifier(f)) match { case Success(e) => e @@ -1220,7 +1255,7 @@ class SessionCatalog( // The session catalog caches some persistent functions in the FunctionRegistry // so there can be duplicates. functions.map { - case f if FunctionRegistry.functionSet.contains(f.funcName) => (f, "SYSTEM") + case f if FunctionRegistry.functionSet.contains(f) => (f, "SYSTEM") case f => (f, "USER") }.distinct } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1a202ecf745c9..5bb0febc943f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -22,6 +22,7 @@ import java.lang.reflect.Modifier import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag +import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ /** @@ -597,8 +598,8 @@ case class MapObjects private( val (initCollection, addElement, getResult): (String, String => String, String) = customCollectionCls match { - case Some(cls) => - // collection + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" val builder = ctx.freshName("collectionBuilder") ( @@ -609,6 +610,20 @@ case class MapObjects private( genValue => s"$builder.$$plus$$eq($genValue);", s"(${cls.getName}) $builder.result();" ) + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + val builder = ctx.freshName("collectionBuilder") + ( + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + s"${cls.getName} $builder = new java.util.ArrayList($dataLength);" + } else { + val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("") + s"${cls.getName} $builder = new ${cls.getName}($param);" + }, + genValue => s"$builder.add($genValue);", + s"$builder;" + ) case None => // array ( @@ -652,6 +667,173 @@ case class MapObjects private( } } +object CollectObjectsToMap { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + /** + * Construct an instance of CollectObjectsToMap case class. + * + * @param keyFunction The function applied on the key collection elements. + * @param valueFunction The function applied on the value collection elements. + * @param inputData An expression that when evaluated returns a map object. + * @param collClass The type of the resulting collection. 
+ */ + def apply( + keyFunction: Expression => Expression, + valueFunction: Expression => Expression, + inputData: Expression, + collClass: Class[_]): CollectObjectsToMap = { + val id = curId.getAndIncrement() + val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id" + val mapType = inputData.dataType.asInstanceOf[MapType] + val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) + val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" + val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" + val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) + CollectObjectsToMap( + keyLoopValue, keyFunction(keyLoopVar), + valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar), + inputData, collClass) + } +} + +/** + * Expression used to convert a Catalyst Map to an external Scala Map. + * The collection is constructed using the associated builder, obtained by calling `newBuilder` + * on the collection's companion object. + * + * @param keyLoopValue the name of the loop variable that is used when iterating over the key + * collection, and which is used as input for the `keyLambdaFunction` + * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param valueLoopValue the name of the loop variable that is used when iterating over the value + * collection, and which is used as input for the `valueLambdaFunction` + * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over + * the value collection, and which is used as input for the + * `valueLambdaFunction` + * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param inputData An expression that when evaluated returns a map object. + * @param collClass The type of the resulting collection. + */ +case class CollectObjectsToMap private( + keyLoopValue: String, + keyLambdaFunction: Expression, + valueLoopValue: String, + valueLoopIsNull: String, + valueLambdaFunction: Expression, + inputData: Expression, + collClass: Class[_]) extends Expression with NonSQLExpression { + + override def nullable: Boolean = inputData.nullable + + override def children: Seq[Expression] = + keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def dataType: DataType = ObjectType(collClass) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. 
+ def inputDataType(dataType: DataType) = dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => dataType + } + + val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] + val keyElementJavaType = ctx.javaType(mapType.keyType) + ctx.addMutableState(keyElementJavaType, keyLoopValue, "") + val genKeyFunction = keyLambdaFunction.genCode(ctx) + val valueElementJavaType = ctx.javaType(mapType.valueType) + ctx.addMutableState("boolean", valueLoopIsNull, "") + ctx.addMutableState(valueElementJavaType, valueLoopValue, "") + val genValueFunction = valueLambdaFunction.genCode(ctx) + val genInputData = inputData.genCode(ctx) + val dataLength = ctx.freshName("dataLength") + val loopIndex = ctx.freshName("loopIndex") + val tupleLoopValue = ctx.freshName("tupleLoopValue") + val builderValue = ctx.freshName("builderValue") + + val getLength = s"${genInputData.value}.numElements()" + + val keyArray = ctx.freshName("keyArray") + val valueArray = ctx.freshName("valueArray") + val getKeyArray = + s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" + val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getValueArray = + s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" + val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value" + def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) = + lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) + val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) + + val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + + val builderClass = classOf[Builder[_, _]].getName + val constructBuilder = s""" + $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder(); + $builderValue.sizeHint($dataLength); + """ + + val tupleClass = classOf[(_, _)].getName + val appendToBuilder = s""" + $tupleClass $tupleLoopValue; + + if (${genValueFunction.isNull}) { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); + } else { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); + } + + $builderValue.$$plus$$eq($tupleLoopValue); + """ + val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" + + val code = s""" + ${genInputData.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + + if (!${genInputData.isNull}) { + int $dataLength = $getLength; + $constructBuilder + $getKeyArray + $getValueArray + + int $loopIndex = 0; + while ($loopIndex < $dataLength) { + $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar); + $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar); + $valueLoopNullCheck + + ${genKeyFunction.code} + ${genValueFunction.code} + + $appendToBuilder + + $loopIndex += 1; + } + + $getBuilderResult + } + """ + ev.copy(code = code, isNull = genInputData.isNull) + } +} + object 
ExternalMapToCatalyst { private val curId = new java.util.concurrent.atomic.AtomicInteger() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 4ed6728994193..bd144c9575c72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -278,7 +278,7 @@ class JacksonParser( // We cannot parse this token based on the given data type. So, we throw a // RuntimeException and this exception will be caught by `parse` method. throw new RuntimeException( - s"Failed to parse a value for data type $dataType (current token: $token).") + s"Failed to parse a value for data type ${dataType.catalogString} (current token: $token).") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 51f749a8bf857..66b8ca62e5e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -383,22 +383,27 @@ object LikeSimplification extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Like(input, Literal(pattern, StringType)) => - pattern.toString match { - case startsWith(prefix) if !prefix.endsWith("\\") => - StartsWith(input, Literal(prefix)) - case endsWith(postfix) => - EndsWith(input, Literal(postfix)) - // 'a%a' pattern is basically same with 'a%' && '%a'. - // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. - case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => - And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)), - And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) - case contains(infix) if !infix.endsWith("\\") => - Contains(input, Literal(infix)) - case equalTo(str) => - EqualTo(input, Literal(str)) - case _ => - Like(input, Literal.create(pattern, StringType)) + if (pattern == null) { + // If pattern is null, return null value directly, since "col like null" == null. + Literal(null, BooleanType) + } else { + pattern.toString match { + case startsWith(prefix) if !prefix.endsWith("\\") => + StartsWith(input, Literal(prefix)) + case endsWith(postfix) => + EndsWith(input, Literal(postfix)) + // 'a%a' pattern is basically same with 'a%' && '%a'. + // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. 
+ case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => + And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)), + And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) + case contains(infix) if !infix.endsWith("\\") => + Contains(input, Literal(infix)) + case equalTo(str) => + EqualTo(input, Literal(str)) + case _ => + Like(input, Literal.create(pattern, StringType)) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index d16fae56b3d4a..e49970df80457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -51,19 +51,17 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) } -case class HintInfo( - isBroadcastable: Option[Boolean] = None) { +case class HintInfo(broadcast: Boolean = false) { /** Must be called when computing stats for a join operator to reset hints. */ - def resetForJoin(): HintInfo = copy( - isBroadcastable = None - ) + def resetForJoin(): HintInfo = copy(broadcast = false) override def toString: String = { - if (productIterator.forall(_.asInstanceOf[Option[_]].isEmpty)) { - "none" - } else { - isBroadcastable.map(x => s"isBroadcastable=$x").getOrElse("") + val hints = scala.collection.mutable.ArrayBuffer.empty[String] + if (broadcast) { + hints += "broadcast" } + + if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index efb42292634ad..746c3e8950f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -32,7 +32,7 @@ import org.apache.spark.unsafe.types.UTF8String * Helper functions for converting between internal and external date and time representations. * Dates are exposed externally as java.sql.Date and are represented internally as the number of * dates since the Unix epoch (1970-01-01). Timestamps are exposed externally as java.sql.Timestamp - * and are stored internally as longs, which are capable of storing timestamps with 100 nanosecond + * and are stored internally as longs, which are capable of storing timestamps with microsecond * precision. 
*/ object DateTimeUtils { @@ -399,13 +399,14 @@ object DateTimeUtils { digitsMilli += 1 } - if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) { - return None + // We are truncating the nanosecond part, which results in loss of precision + while (digitsMilli > 6) { + segments(6) /= 10 + digitsMilli -= 1 } - // Instead of return None, we truncate the fractional seconds to prevent inserting NULL - if (segments(6) > 999999) { - segments(6) = segments(6).toString.take(6).toInt + if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) { + return None } if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 54bee02e44e43..3ea808926e10b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -352,7 +352,7 @@ object SQLConf { val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") - .intConf + .timeConf(TimeUnit.SECONDS) .createWithDefault(5 * 60) // This is only used for the thriftserver @@ -991,7 +991,7 @@ class SQLConf extends Serializable with Logging { def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) - def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) + def broadcastTimeout: Long = getConf(BROADCAST_TIMEOUT) def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 80916ee9c5379..1f1fb51addfd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -126,7 +126,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(decimal: BigDecimal): Decimal = { this.decimalVal = decimal this.longVal = 0L - this._precision = decimal.precision + if (decimal.precision <= decimal.scale) { + // For Decimal, we expect the precision is equal to or large than the scale, however, + // in BigDecimal, the digit count starts from the leftmost nonzero digit of the exact + // result. For example, the precision of 0.01 equals to 1 based on the definition, but + // the scale is 2. The expected precision should be 3. 
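A worked example (illustrative only) of the precision adjustment in Decimal.set above: java.math.BigDecimal reports "0.01" as precision 1 and scale 2, so the stored precision becomes scale + 1 = 3 to keep precision >= scale.

import org.apache.spark.sql.types.Decimal

val bd = BigDecimal("0.01")
// BigDecimal counts digits from the leftmost nonzero digit: precision = 1, scale = 2.
assert(bd.precision == 1 && bd.scale == 2)

// With the adjustment above, Decimal stores precision = scale + 1 = 3.
val d = Decimal(bd)
assert(d.precision == 3 && d.scale == 2)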
+ this._precision = decimal.scale + 1 + } else { + this._precision = decimal.precision + } this._scale = decimal.scale this } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 70ad064f93ebc..ff2414b174acb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } + test("serialize and deserialize arbitrary map types") { + val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( + 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) + assert(mapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mapDeserializer = deserializerFor[Map[Int, Int]] + assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) + + import scala.collection.immutable.HashMap + val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( + 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) + assert(hashMapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] + assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) + + import scala.collection.mutable.{LinkedHashMap => LHMap} + val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( + 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) + assert(linkedHashMapSerializer.dataType.head.dataType == + MapType(LongType, StringType, valueContainsNull = true)) + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] + assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) + } + private val dataTypeForComplexData = dataTypeFor[ComplexData] private val typeOfComplexData = typeOf[ComplexData] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 3d5148008c628..9782b5fb0d266 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -36,17 +36,17 @@ class ResolveHintsSuite extends AnalysisTest { test("case-sensitive or insensitive parameters") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = true) checkAnalysis( @@ -58,28 +58,28 @@ class ResolveHintsSuite extends AnalysisTest { test("multiple broadcast hint aliases") { checkAnalysis( UnresolvedHint("MAPJOIN", 
Seq("table", "table2"), table("table").join(table("table2"))), - Join(ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), - ResolvedHint(testRelation2, HintInfo(isBroadcastable = Option(true))), Inner, None), + Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None), caseSensitive = false) } test("do not traverse past existing broadcast hints") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), - ResolvedHint(table("table").where('a > 1), HintInfo(isBroadcastable = Option(true)))), - ResolvedHint(testRelation.where('a > 1), HintInfo(isBroadcastable = Option(true))).analyze, + ResolvedHint(table("table").where('a > 1), HintInfo(broadcast = true))), + ResolvedHint(testRelation.where('a > 1), HintInfo(broadcast = true)).analyze, caseSensitive = false) } test("should work for subqueries") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = false) // Negative case: if the alias doesn't match, don't match the original table name. @@ -104,7 +104,7 @@ class ResolveHintsSuite extends AnalysisTest { |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable """.stripMargin ), - ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(isBroadcastable = Option(true))) + ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(broadcast = true)) .select('a).analyze, caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2624f5586fd5d..2ac11598e63d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -484,24 +484,50 @@ class TypeCoercionSuite extends PlanTest { } test("coalesce casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1.0) - :: Literal(1) - :: Literal.create(1.0, FloatType) - :: Nil), - Coalesce(Cast(Literal(1.0), DoubleType) - :: Cast(Literal(1), DoubleType) - :: Cast(Literal.create(1.0, FloatType), DoubleType) - :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1L) - :: Literal(1) - :: Literal(new java.math.BigDecimal("1000000000000000000000")) - :: Nil), - Coalesce(Cast(Literal(1L), DecimalType(22, 0)) - :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) - :: Nil)) + val rule = TypeCoercion.FunctionArgumentConversion + + val intLit = Literal(1) + val longLit = Literal.create(1L) + val doubleLit = Literal(1.0) + val stringLit = Literal.create("c", StringType) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + + ruleTest(rule, + Coalesce(Seq(doubleLit, intLit, 
floatLit)), + Coalesce(Seq(Cast(doubleLit, DoubleType), + Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(longLit, intLit, decimalLit)), + Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), + Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit)), + Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, intLit)), + Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), + Cast(intLit, FloatType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), + Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), + Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), + Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), + Cast(doubleLit, StringType), Cast(stringLit, StringType)))) } test("CreateArray casts") { @@ -675,6 +701,14 @@ class TypeCoercionSuite extends PlanTest { test("type coercion for If") { val rule = TypeCoercion.IfCoercion + val intLit = Literal(1) + val doubleLit = Literal(1.0) + val trueLit = Literal.create(true, BooleanType) + val falseLit = Literal.create(false, BooleanType) + val stringLit = Literal.create("c", StringType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), @@ -685,12 +719,32 @@ class TypeCoercionSuite extends PlanTest { If(Literal.create(null, BooleanType), Literal(1), Literal(1))) ruleTest(rule, - If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(AssertTrue(trueLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(trueLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(AssertTrue(falseLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(falseLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(trueLit, intLit, doubleLit), + If(trueLit, Cast(intLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, doubleLit), + If(trueLit, Cast(floatLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, decimalLit), + If(trueLit, Cast(floatLit, DoubleType), Cast(decimalLit, DoubleType))) + + ruleTest(rule, + If(falseLit, stringLit, doubleLit), + If(falseLit, stringLit, Cast(doubleLit, StringType))) ruleTest(rule, - If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(trueLit, timestampLit, stringLit), + If(trueLit, Cast(timestampLit, StringType), stringLit)) } test("type coercion for CaseKeyWhen") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 1759ac04c0033..557b0970b54e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -245,7 +245,6 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("alter table schema") { val catalog = newBasicCatalog() - val tbl1 = catalog.getTable("db2", "tbl1") val newSchema = StructType(Seq( StructField("col1", IntegerType), StructField("new_field_2", StringType), @@ -256,6 +255,16 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(newTbl1.schema == newSchema) } + test("alter table stats") { + val catalog = newBasicCatalog() + val oldTableStats = catalog.getTable("db2", "tbl1").stats + assert(oldTableStats.isEmpty) + val newStats = CatalogStatistics(sizeInBytes = 1) + catalog.alterTableStats("db2", "tbl1", newStats) + val newTableStats = catalog.getTable("db2", "tbl1").stats + assert(newTableStats.get == newStats) + } + test("get table") { assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index be8903000a0d1..dce73b3635e72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -448,6 +448,18 @@ abstract class SessionCatalogSuite extends PlanTest { } } + test("alter table stats") { + withBasicCatalog { catalog => + val tableId = TableIdentifier("tbl1", Some("db2")) + val oldTableStats = catalog.getTableMetadata(tableId).stats + assert(oldTableStats.isEmpty) + val newStats = CatalogStatistics(sizeInBytes = 1) + catalog.alterTableStats(tableId, newStats) + val newTableStats = catalog.getTableMetadata(tableId).stats + assert(newTableStats.get == newStats) + } + } + test("alter table add columns") { withBasicCatalog { sessionCatalog => sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) @@ -1209,7 +1221,7 @@ abstract class SessionCatalogSuite extends PlanTest { assert(!catalog.isTemporaryFunction(FunctionIdentifier("func1"))) // Returns false when the function is built-in or hive - assert(FunctionRegistry.builtin.functionExists("sum")) + assert(FunctionRegistry.builtin.functionExists(FunctionIdentifier("sum"))) assert(!catalog.isTemporaryFunction(FunctionIdentifier("sum"))) assert(!catalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric"))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 6af0cde73538b..f4d5a4471d896 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -23,6 +23,7 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts.implicitCast import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer @@ -223,6 +224,14 @@ class MathExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper { def f: (Double) => Double = (x: Double) => 1 / math.tan(x) testUnary(Cot, f) checkConsistencyBetweenInterpretedAndCodegen(Cot, DoubleType) + val nullLit = Literal.create(null, NullType) + val intNullLit = Literal.create(null, IntegerType) + val intLit = Literal.create(1, IntegerType) + checkEvaluation(checkDataTypeAndCast(Cot(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(intNullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(intLit)), 1 / math.tan(1), EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(-intLit)), 1 / math.tan(-1), EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(0)), 1 / math.tan(0), EmptyRow) } test("atan") { @@ -250,6 +259,11 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) } + def checkDataTypeAndCast(expression: UnaryMathExpression): Expression = { + val expNew = implicitCast(expression.child, expression.inputTypes(0)).getOrElse(expression) + expression.withNewChildren(Seq(expNew)) + } + test("ceil") { testUnary(Ceil, (d: Double) => math.ceil(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) @@ -262,12 +276,22 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doublePi: Double = 3.1415 val floatPi: Float = 3.1415f val longLit: Long = 12345678901234567L - checkEvaluation(Ceil(doublePi), 4L, EmptyRow) - checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow) - checkEvaluation(Ceil(longLit), longLit, EmptyRow) - checkEvaluation(Ceil(-doublePi), -3L, EmptyRow) - checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow) - checkEvaluation(Ceil(-longLit), -longLit, EmptyRow) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + checkEvaluation(checkDataTypeAndCast(Ceil(doublePi)), 4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(floatPi)), 4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(longLit)), longLit, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-doublePi)), -3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-floatPi)), -3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-longLit)), -longLit, EmptyRow) + + checkEvaluation(checkDataTypeAndCast(Ceil(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(floatNullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(0)), 0L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(1)), 1L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(1234567890123456L)), 1234567890123456L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(0.01)), 1L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-0.10)), 0L, EmptyRow) } test("floor") { @@ -282,12 +306,22 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doublePi: Double = 3.1415 val floatPi: Float = 3.1415f val longLit: Long = 12345678901234567L - checkEvaluation(Floor(doublePi), 3L, EmptyRow) - checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow) - checkEvaluation(Floor(longLit), longLit, EmptyRow) - checkEvaluation(Floor(-doublePi), -4L, EmptyRow) - checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow) - checkEvaluation(Floor(-longLit), -longLit, EmptyRow) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + checkEvaluation(checkDataTypeAndCast(Floor(doublePi)), 3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(floatPi)), 3L, EmptyRow) + 
checkEvaluation(checkDataTypeAndCast(Floor(longLit)), longLit, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-doublePi)), -4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-floatPi)), -4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-longLit)), -longLit, EmptyRow) + + checkEvaluation(checkDataTypeAndCast(Floor(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(floatNullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(0)), 0L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(1)), 1L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(1234567890123456L)), 1234567890123456L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(0.01)), 0L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-0.10)), -1L, EmptyRow) } test("factorial") { @@ -541,10 +575,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val intPi: Int = 314159265 val longPi: Long = 31415926535897932L val bdPi: BigDecimal = BigDecimal(31415927L, 7) + val floatPi: Float = 3.1415f val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, 3.1416, 3.14159, 3.141593) + val floatResults: Seq[Float] = Seq(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 3.1f, 3.14f, + 3.141f, 3.1415f, 3.1415f, 3.1415f) + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ Seq.fill[Short](7)(31415) @@ -563,10 +601,12 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + checkEvaluation(Round(floatPi, scale), floatResults(i), EmptyRow) checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow) checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow) + checkEvaluation(BRound(floatPi, scale), floatResults(i), EmptyRow) } val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 5064a1f63f83d..394c0a091e390 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -97,14 +97,30 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doubleLit = Literal.create(2.2, DoubleType) val stringLit = Literal.create("c", StringType) val nullLit = Literal.create(null, NullType) - + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.01f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal.create(10.2, DecimalType(20, 2)) + + assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, floatLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatLit, decimalLit)).dataType == DoubleType) + + assert(analyze(new Nvl(timestampLit, 
stringLit)).dataType == StringType) assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType) assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType) assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType) assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType) + + assert(analyze(new Nvl(floatLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatNullLit, intLit)).dataType == FloatType) } test("AtLeastNNonNulls") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index fdde89d079bc0..50398788c605c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.optimizer -/* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{BooleanType, StringType} class LikeSimplificationSuite extends PlanTest { @@ -100,4 +100,10 @@ class LikeSimplificationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("null pattern") { + val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index d004d04569772..fef39a5b6a32f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -575,14 +575,6 @@ class PlanParserSuite extends PlanTest { ) ) - comparePlans( - parsePlan("SELECT /*+ HINT1(a, array(1, 2, 3)) */ * from t"), - UnresolvedHint("HINT1", Seq($"a", - UnresolvedFunction("array", Literal(1) :: Literal(2) :: Literal(3) :: Nil, false)), - table("t").select(star()) - ) - ) - comparePlans( parsePlan("SELECT /*+ HINT1(a, 5, 'a', b) */ * from t"), UnresolvedHint("HINT1", Seq($"a", Literal(5), Literal("a"), $"b"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 2afea6dd3d37c..833f5a71994f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -45,11 +45,11 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { expectedStatsCboOn = 
filterStatsCboOn, expectedStatsCboOff = filterStatsCboOff) - val broadcastHint = ResolvedHint(filter, HintInfo(isBroadcastable = Option(true))) + val broadcastHint = ResolvedHint(filter, HintInfo(broadcast = true)) checkStats( broadcastHint, - expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(isBroadcastable = Option(true))), - expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(isBroadcastable = Option(true))) + expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(broadcast = true)), + expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(broadcast = true)) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 9799817494f15..c8cf16d937352 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -34,6 +34,22 @@ class DateTimeUtilsSuite extends SparkFunSuite { ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt } + test("nanoseconds truncation") { + def checkStringToTimestamp(originalTime: String, expectedParsedTime: String) { + val parsedTimestampOp = DateTimeUtils.stringToTimestamp(UTF8String.fromString(originalTime)) + assert(parsedTimestampOp.isDefined, "timestamp with nanoseconds was not parsed correctly") + assert(DateTimeUtils.timestampToString(parsedTimestampOp.get) === expectedParsedTime) + } + + checkStringToTimestamp("2015-01-02 00:00:00.123456789", "2015-01-02 00:00:00.123456") + checkStringToTimestamp("2015-01-02 00:00:00.100000009", "2015-01-02 00:00:00.1") + checkStringToTimestamp("2015-01-02 00:00:00.000050000", "2015-01-02 00:00:00.00005") + checkStringToTimestamp("2015-01-02 00:00:00.12005", "2015-01-02 00:00:00.12005") + checkStringToTimestamp("2015-01-02 00:00:00.100", "2015-01-02 00:00:00.1") + checkStringToTimestamp("2015-01-02 00:00:00.000456789", "2015-01-02 00:00:00.000456") + checkStringToTimestamp("1950-01-02 00:00:00.000456789", "1950-01-02 00:00:00.000456") + } + test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 93c231e30b49b..144f3d688d402 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -32,6 +32,16 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("0.09")), "0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("0.9")), "0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("0.90")), "0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("0.0")), "0.0", 2, 1) + checkDecimal(Decimal(BigDecimal("0")), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("1.0")), "1.0", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.09")), "-0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("-0.9")), "-0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.90")), "-0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("-1.0")), "-1.0", 2, 1) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) diff --git 
a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 27d32b5dca431..0c5f3f22e31e8 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,3 +5,4 @@ org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.TextSocketSourceProvider +org.apache.spark.sql.execution.streaming.RateSourceProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8abec85ee102a..d28ff7888d127 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -131,7 +131,7 @@ private[sql] object Dataset { * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) - * .groupBy(department("name"), "gender") + * .groupBy(department("name"), people("gender")) * .agg(avg(people("salary")), max(people("age"))) * }}} * @@ -141,9 +141,9 @@ private[sql] object Dataset { * Dataset people = spark.read().parquet("..."); * Dataset department = spark.read().parquet("..."); * - * people.filter("age".gt(30)) - * .join(department, people.col("deptId").equalTo(department("id"))) - * .groupBy(department.col("name"), "gender") + * people.filter(people.col("age").gt(30)) + * .join(department, people.col("deptId").equalTo(department.col("id"))) + * .groupBy(department.col("name"), people.col("gender")) * .agg(avg(people.col("salary")), max(people.col("age"))); * }}} * @@ -1734,10 +1734,11 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * This is equivalent to `UNION ALL` in SQL. * - * To do a SQL-style set union (that does deduplication of elements), use this function followed - * by a [[distinct]]. + * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does + * deduplication of elements), use this function followed by a [[distinct]]. + * + * Also as standard in SQL, this function resolves columns by position (not by name). * * @group typedrel * @since 2.0.0 @@ -1747,10 +1748,11 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * This is equivalent to `UNION ALL` in SQL. * - * To do a SQL-style set union (that does deduplication of elements), use this function followed - * by a [[distinct]]. + * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does + * deduplication of elements), use this function followed by a [[distinct]]. + * + * Also as standard in SQL, this function resolves columns by position (not by name). 
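A short illustration of the position-based resolution called out in the updated union docs (assumes a SparkSession with `import spark.implicits._`; the column names are arbitrary):

val df1 = Seq((1, "a")).toDF("id", "name")
val df2 = Seq(("b", 2)).toDF("name", "id")

// union pairs columns by position, not by name, so reorder explicitly when the schemas differ:
val unioned = df1.union(df2.select("id", "name"))
// unioned has schema (id, name) and two rows.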
* * @group typedrel * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 17671ea8685b9..86574e2f71d92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.Map import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -166,6 +167,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.2.0 */ implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Maps + /** @since 2.3.0 */ + implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Arrays /** @since 1.6.1 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 1bceac41b9de7..ad01b889429c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -61,7 +61,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | dataType: ${udf.dataType} """.stripMargin) - functionRegistry.registerFunction(name, udf.builder) + functionRegistry.createOrReplaceTempFunction(name, udf.builder) } /** @@ -75,7 +75,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) udaf } @@ -91,7 +91,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = { def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) udf } @@ -113,7 +113,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try($inputTypes).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) }""") } @@ -130,7 +130,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { | val func = f$anyCast.call($anyParams) - | functionRegistry.registerFunction( + | functionRegistry.createOrReplaceTempFunction( | name, | (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) |}""".stripMargin) @@ -146,7 +146,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - 
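A minimal sketch of what the new newMapEncoder implicit enables; assumes a SparkSession named `spark`:

import spark.implicits._   // newMapEncoder is picked up from here

val ds = Seq(Map("a" -> 1, "b" -> 2)).toDS()             // Dataset[Map[String, Int]]
val collected: Array[Map[String, Int]] = ds.collect()    // round-trips back to Scala maps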
functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -159,7 +159,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -172,7 +172,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -185,7 +185,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -211,7 +211,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -224,7 +224,7 @@ class UDFRegistration private[sql] (functionRegistry: 
FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -237,7 +237,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -250,7 +250,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -263,7 +263,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -276,7 +276,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -289,7 +289,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -302,7 +302,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -315,7 +315,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, 
dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -328,7 +328,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -341,7 +341,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), 
nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -367,7 +367,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -380,7 +380,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -393,7 +393,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -406,7 +406,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -419,7 +419,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -432,7 +432,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = 
Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -510,7 +510,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -521,7 +521,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -532,7 +532,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -543,7 +543,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -554,7 +554,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -565,7 +565,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { val func = 
f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -576,7 +576,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -587,7 +587,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -598,7 +598,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -609,7 +609,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -620,7 +620,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -631,7 +631,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -642,7 +642,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: 
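All of these `register` overloads now route through `createOrReplaceTempFunction`, so registering a name a second time replaces the earlier temporary function without touching the builtin registry. A hedged sketch of the observable behavior (the UDF bodies are made up):

import org.apache.spark.sql.SparkSession

object UdfReplaceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udf-replace-sketch").getOrCreate()

    // First registration of the temporary function "plusOne"
    spark.udf.register("plusOne", (x: Int) => x + 1)
    // Registering the same name again replaces the previous temporary function
    spark.udf.register("plusOne", (x: Int) => x + 100)

    // Uses the latest definition, so this prints 101
    spark.sql("SELECT plusOne(1)").show()

    spark.stop()
  }
}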
DataType): Unit = { val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -653,7 +653,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -664,7 +664,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -675,7 +675,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -686,7 +686,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -697,7 +697,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -708,7 +708,7 @@ class UDFRegistration private[sql] 
(functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -730,7 +730,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -741,7 +741,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f13294c925e36..ea86f6e00fefa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -114,7 +114,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. 
*/ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats(conf).hints.isBroadcastable.getOrElse(false) || + plan.stats(conf).hints.broadcast || (plan.stats(conf).sizeInBytes >= 0 && plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index f69a688555bbf..04c130314388a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext} -import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} +import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} @@ -347,8 +347,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } override def inputRDDs(): Seq[RDD[InternalRow]] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) :: Nil + val rdd = if (start == end || (start < end ^ 0 < step)) { + new EmptyRDD[InternalRow](sqlContext.sparkContext) + } else { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + rdd :: Nil } protected override def doProduce(ctx: CodegenContext): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 2de14c90ec757..2f273b63e8348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -54,7 +54,7 @@ case class AnalyzeColumnCommand( // Newly computed column stats should override the existing ones. colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) - sessionState.catalog.alterTable(tableMeta.copy(stats = Some(statistics))) + sessionState.catalog.alterTableStats(tableIdentWithDB, statistics) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 3183c7911b1fb..3c59b982c2dca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -69,7 +69,7 @@ case class AnalyzeTableCommand( // Update the metastore if the above statistics of the table are different from those // recorded in the metastore. if (newStats.isDefined) { - sessionState.catalog.alterTable(tableMeta.copy(stats = newStats)) + sessionState.catalog.alterTableStats(tableIdentWithDB, newStats.get) // Refresh the cached data source table in the catalog. 
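The `RangeExec` change above falls back to an `EmptyRDD` whenever the requested range is empty, i.e. `start == end` or the step points away from `end`. A small sketch of the cases involved, assuming a local session (it mirrors the SPARK-21041 regression test added later in this patch):

import org.apache.spark.sql.SparkSession

object EmptyRangeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("empty-range-sketch").getOrCreate()

    // start == end: no rows
    assert(spark.range(5, 5, 1).count() == 0)
    // step moves away from end (the SPARK-21041 case): no rows
    assert(spark.range(Long.MaxValue - 3, Long.MinValue + 2, 1).count() == 0)
    // an ordinary range is unaffected
    assert(spark.range(0, 3, 1).count() == 3)

    spark.stop()
  }
}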
sessionState.catalog.refreshTable(tableIdentWithDB) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 545082324f0d3..f39a3269efaf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -160,7 +160,7 @@ case class DropFunctionCommand( throw new AnalysisException(s"Specifying a database in DROP TEMPORARY FUNCTION " + s"is not allowed: '${databaseName.get}'") } - if (FunctionRegistry.builtin.functionExists(functionName)) { + if (FunctionRegistry.builtin.functionExists(FunctionIdentifier(functionName))) { throw new AnalysisException(s"Cannot drop native function '$functionName'") } catalog.dropTempFunction(functionName, ifExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 9ccd6792e5da4..b937a8a9f375b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -522,15 +522,15 @@ case class DescribeTableCommand( throw new AnalysisException( s"DESC PARTITION is not allowed on a temporary view: ${table.identifier}") } - describeSchema(catalog.lookupRelation(table).schema, result) + describeSchema(catalog.lookupRelation(table).schema, result, header = false) } else { val metadata = catalog.getTableMetadata(table) if (metadata.schema.isEmpty) { // In older version(prior to 2.1) of Spark, the table schema can be empty and should be // inferred at runtime. We should still support it. 
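Both analyze commands now persist their results through `alterTableStats` instead of rewriting the full table metadata via `alterTable`. A hedged end-to-end sketch that exercises them (table name and data are illustrative):

import org.apache.spark.sql.SparkSession

object AnalyzeStatsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("analyze-stats-sketch").getOrCreate()

    spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING) USING parquet")
    spark.sql("INSERT INTO src VALUES (1, 'a'), (2, 'b')")

    // Both commands now persist their results via alterTableStats
    spark.sql("ANALYZE TABLE src COMPUTE STATISTICS")
    spark.sql("ANALYZE TABLE src COMPUTE STATISTICS FOR COLUMNS key, value")

    // The collected statistics appear under the extended description
    spark.sql("DESC EXTENDED src").show(truncate = false)

    spark.stop()
  }
}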
- describeSchema(sparkSession.table(metadata.identifier).schema, result) + describeSchema(sparkSession.table(metadata.identifier).schema, result, header = false) } else { - describeSchema(metadata.schema, result) + describeSchema(metadata.schema, result, header = false) } describePartitionInfo(metadata, result) @@ -550,7 +550,7 @@ case class DescribeTableCommand( private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { if (table.partitionColumnNames.nonEmpty) { append(buffer, "# Partition Information", "", "") - describeSchema(table.partitionSchema, buffer) + describeSchema(table.partitionSchema, buffer, header = true) } } @@ -601,8 +601,13 @@ case class DescribeTableCommand( table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) } - private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + private def describeSchema( + schema: StructType, + buffer: ArrayBuffer[Row], + header: Boolean): Unit = { + if (header) { + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + } schema.foreach { column => append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 21d75a404911b..e05a8d5f02bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -215,9 +215,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] private def readDataSourceTable(r: CatalogRelation): LogicalPlan = { val table = r.tableMeta val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table) - val cache = sparkSession.sessionState.catalog.tableRelationCache + val catalogProxy = sparkSession.sessionState.catalog - val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { + val plan = catalogProxy.getCachedPlan(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) val dataSource = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala index 159aef220be15..43591a9ff524a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util._ @@ -65,7 +66,8 @@ class FailureSafeParser[IN]( case DropMalformedMode => Iterator.empty case FailFastMode => - throw e.cause + throw new SparkException("Malformed records are detected in record parsing. 
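With the `FailureSafeParser` change, a FAILFAST parse failure is surfaced as a `SparkException` that names the parse mode rather than rethrowing the bare cause. A sketch, with a made-up malformed record:

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructType}

object FailFastSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("failfast-sketch").getOrCreate()
    import spark.implicits._

    val schema = new StructType().add("a", IntegerType)
    val records = Seq("""{"a": 1}""", """{"a":""").toDS()  // the second record is malformed

    try {
      spark.read.schema(schema).option("mode", "FAILFAST").json(records).collect()
    } catch {
      case e: SparkException =>
        // The failure chain now carries:
        // "Malformed records are detected in record parsing. Parse Mode: FAILFAST."
        println(e.getMessage)
    }

    spark.stop()
  }
}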
" + + s"Parse Mode: ${FailFastMode.name}.", e.cause) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index fb632cf2bb70e..a270a6451d5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -21,6 +21,7 @@ import java.util.Comparator import com.fasterxml.jackson.core._ +import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil @@ -61,7 +62,8 @@ private[sql] object JsonInferSchema { case DropMalformedMode => None case FailFastMode => - throw e + throw new SparkException("Malformed records are detected in schema inference. " + + s"Parse Mode: ${FailFastMode.name}.", e) } } } @@ -231,8 +233,9 @@ private[sql] object JsonInferSchema { case FailFastMode => // If `other` is not struct type, consider it as malformed one and throws an exception. - throw new RuntimeException("Failed to infer a common schema. Struct types are expected" + - s" but ${other.catalogString} was found.") + throw new SparkException("Malformed records are detected in schema inference. " + + s"Parse Mode: ${FailFastMode.name}. Reasons: Failed to infer a common schema. " + + s"Struct types are expected, but `${other.catalogString}` was found.") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala new file mode 100644 index 0000000000000..e61a8eb628891 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.TimeUnit + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types._ +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * A source that generates increment long values with timestamps. 
Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateSourceProvider extends StreamSourceProvider with DataSourceRegister { + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = + (shortName(), RateSourceProvider.SCHEMA) + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val params = CaseInsensitiveMap(parameters) + + val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + + "must be positive") + } + + val rampUpTimeSeconds = + params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + + "must not be negative") + } + + val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( + sqlContext.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + + "must be positive") + } + + new RateStreamSource( + sqlContext, + metadataPath, + rowsPerSecond, + rampUpTimeSeconds, + numPartitions, + params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing + ) + } + override def shortName(): String = "rate" +} + +object RateSourceProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 +} + +class RateStreamSource( + sqlContext: SQLContext, + metadataPath: String, + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int, + useManualClock: Boolean) extends Source with Logging { + + import RateSourceProvider._ + import RateStreamSource._ + + val clock = if (useManualClock) new ManualClock else new SystemClock + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. 
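The options listed above can be exercised with a short streaming query; a minimal sketch (the console sink, trigger, and timeout are illustrative choices, and `useManualClock` is deliberately left to the tests):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object RateSourceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rate-source-sketch").getOrCreate()

    // Fixed schema: (timestamp TIMESTAMP, value LONG)
    val stream = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .option("rampUpTime", "5s")
      .option("numPartitions", "2")
      .load()

    val query = stream.writeStream
      .format("console")
      .trigger(Trigger.ProcessingTime("1 second"))
      .start()

    query.awaitTermination(10000)  // let it run for ~10 seconds
    query.stop()
    spark.stop()
  }
}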
Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private val startTimeMs = { + val metadataLog = + new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ + @volatile private var lastTimeMs = startTimeMs + + override def schema: StructType = RateSourceProvider.SCHEMA + + override def getOffset: Option[Offset] = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. 
Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return sqlContext.internalCreateDataFrame(sqlContext.sparkContext.emptyRDD, schema) + } + + val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + + val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => + val relative = math.round((v - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) + } + sqlContext.internalCreateDataFrame(rdd, schema) + } + + override def stop(): Unit = {} + + override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" +} + +object RateStreamSource { + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 67ec1325b321e..8d2e1f32da059 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1020,7 +1020,7 @@ object functions { */ def broadcast[T](df: Dataset[T]): Dataset[T] = { Dataset[T](df.sparkSession, - ResolvedHint(df.logicalPlan, HintInfo(isBroadcastable = Option(true))))(df.exprEnc) + ResolvedHint(df.logicalPlan, HintInfo(broadcast = true)))(df.exprEnc) } /** @@ -1210,7 +1210,7 @@ object functions { /** * Creates a new struct column. * If the input column is a column in a `DataFrame`, or a derived column expression - * that is named (i.e. aliased), its name would be remained as the StructField's name, + * that is named (i.e. aliased), its name would be retained as the StructField's name, * otherwise, the newly generated StructField's name would be auto generated as * `col` with a suffix `index + 1`, i.e. col1, col2, col3, ... 
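The `broadcast` function above now attaches `HintInfo(broadcast = true)`, which is exactly what the updated `canBroadcast` in `SparkStrategies` checks. The user-facing API is unchanged; a brief sketch with made-up tables:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

object BroadcastHintSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("broadcast-hint-sketch").getOrCreate()
    import spark.implicits._

    val facts = Seq((1, 10.0), (2, 20.0), (1, 5.0)).toDF("dim_id", "amount")
    val dims = Seq((1, "a"), (2, "b")).toDF("id", "name")

    // broadcast() marks `dims` with HintInfo(broadcast = true); canBroadcast then
    // plans a broadcast join regardless of spark.sql.autoBroadcastJoinThreshold.
    val joined = facts.join(broadcast(dims), $"dim_id" === $"id")
    joined.explain()  // the physical plan should contain BroadcastHashJoin

    spark.stop()
  }
}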
* diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 3ba37addfc8b4..4ca3b6406a328 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1399,4 +1399,65 @@ public void testSerializeNull() { ds1.map((MapFunction) b -> b, encoder); Assert.assertEquals(beans, ds2.collectAsList()); } + + @Test + public void testSpecificLists() { + SpecificListsBean bean = new SpecificListsBean(); + ArrayList arrayList = new ArrayList<>(); + arrayList.add(1); + bean.setArrayList(arrayList); + LinkedList linkedList = new LinkedList<>(); + linkedList.add(1); + bean.setLinkedList(linkedList); + bean.setList(Collections.singletonList(1)); + List beans = Collections.singletonList(bean); + Dataset dataset = + spark.createDataset(beans, Encoders.bean(SpecificListsBean.class)); + Assert.assertEquals(beans, dataset.collectAsList()); + } + + public static class SpecificListsBean implements Serializable { + private ArrayList arrayList; + private LinkedList linkedList; + private List list; + + public ArrayList getArrayList() { + return arrayList; + } + + public void setArrayList(ArrayList arrayList) { + this.arrayList = arrayList; + } + + public LinkedList getLinkedList() { + return linkedList; + } + + public void setLinkedList(LinkedList linkedList) { + this.linkedList = linkedList; + } + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SpecificListsBean that = (SpecificListsBean) o; + return Objects.equal(arrayList, that.arrayList) && + Objects.equal(linkedList, that.linkedList) && + Objects.equal(list, that.list); + } + + @Override + public int hashCode() { + return Objects.hashCode(arrayList, linkedList, list); + } + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index d37bc9be7a2f4..d2f80c7a1ac79 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -65,11 +65,18 @@ select ceiling(0); select ceiling(1); select ceil(1234567890123456); select ceiling(1234567890123456); +select ceil(0.01); +select ceiling(-0.10); -- floor select floor(0); select floor(1); select floor(1234567890123456); +select floor(0.01); +select floor(-0.10); + +-- comparison operator +select 1 > 0.00001 -- mod select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null); diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index 678a3f0f0a3c6..ba8bc936f0c79 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -15,7 +15,6 @@ DESC test_change -- !query 1 schema struct -- !query 1 output -# col_name data_type comment a int b string c int @@ -35,7 +34,6 @@ DESC test_change -- !query 3 schema struct -- !query 3 output -# col_name data_type comment a int b string c int @@ -55,7 +53,6 @@ DESC test_change -- !query 5 schema struct -- !query 5 output -# col_name data_type comment a int b string c int @@ -94,7 +91,6 @@ DESC test_change -- !query 8 
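The new `operators.sql` queries pin down `ceil`/`floor` on decimal literals plus an integer-versus-decimal comparison; the expected values can be reproduced directly (a sketch):

import org.apache.spark.sql.SparkSession

object CeilFloorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ceil-floor-sketch").getOrCreate()

    val row = spark.sql(
      "SELECT ceil(0.01), ceiling(-0.10), floor(0.01), floor(-0.10), 1 > 0.00001").head()

    // Matches the golden file further down: 1, 0, 0, -1, true
    println(row)

    spark.stop()
  }
}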
schema struct -- !query 8 output -# col_name data_type comment a int b string c int @@ -129,7 +125,6 @@ DESC test_change -- !query 12 schema struct -- !query 12 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -148,7 +143,6 @@ DESC test_change -- !query 14 schema struct -- !query 14 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -168,7 +162,6 @@ DESC test_change -- !query 16 schema struct -- !query 16 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -193,7 +186,6 @@ DESC test_change -- !query 18 schema struct -- !query 18 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -237,7 +229,6 @@ DESC test_change -- !query 23 schema struct -- !query 23 output -# col_name data_type comment a int this is column A b string #*02?` c int diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out index 1cc11c475bc40..eece00d603db4 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out @@ -15,7 +15,6 @@ DESC FORMATTED table_with_comment -- !query 1 schema struct -- !query 1 output -# col_name data_type comment a string b int c string @@ -45,7 +44,6 @@ DESC FORMATTED table_with_comment -- !query 3 schema struct -- !query 3 output -# col_name data_type comment a string b int c string @@ -84,7 +82,6 @@ DESC FORMATTED table_comment -- !query 6 schema struct -- !query 6 output -# col_name data_type comment a string b int @@ -111,7 +108,6 @@ DESC formatted table_comment -- !query 8 schema struct -- !query 8 output -# col_name data_type comment a string b int @@ -139,7 +135,6 @@ DESC FORMATTED table_comment -- !query 10 schema struct -- !query 10 output -# col_name data_type comment a string b int diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index de10b29f3c65b..46d32bbc52247 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -54,7 +54,6 @@ DESCRIBE t -- !query 5 schema struct -- !query 5 output -# col_name data_type comment a string b int c string @@ -70,7 +69,6 @@ DESC default.t -- !query 6 schema struct -- !query 6 output -# col_name data_type comment a string b int c string @@ -86,7 +84,6 @@ DESC TABLE t -- !query 7 schema struct -- !query 7 output -# col_name data_type comment a string b int c string @@ -102,7 +99,6 @@ DESC FORMATTED t -- !query 8 schema struct -- !query 8 output -# col_name data_type comment a string b int c string @@ -132,7 +128,6 @@ DESC EXTENDED t -- !query 9 schema struct -- !query 9 output -# col_name data_type comment a string b int c string @@ -162,7 +157,6 @@ DESC t PARTITION (c='Us', d=1) -- !query 10 schema struct -- !query 10 output -# col_name data_type comment a string b int c string @@ -178,7 +172,6 @@ DESC EXTENDED t PARTITION (c='Us', d=1) -- !query 11 schema struct -- !query 11 output -# col_name data_type comment a string b int c string @@ -206,7 +199,6 @@ DESC FORMATTED t PARTITION (c='Us', d=1) -- !query 12 schema struct -- !query 12 output -# col_name data_type comment a string b int c string @@ -268,7 +260,6 @@ DESC temp_v -- !query 16 schema struct -- !query 16 output -# 
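These golden-file updates reflect `describeSchema(..., header = false)`: plain `DESC` output no longer begins with a literal `# col_name data_type comment` row, while the partition-information block keeps its header. A quick way to observe this with a hypothetical partitioned table `t`:

import org.apache.spark.sql.SparkSession

object DescribeHeaderSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("describe-header-sketch").getOrCreate()

    spark.sql("CREATE TABLE IF NOT EXISTS t (a STRING, b INT) USING parquet PARTITIONED BY (b)")

    // The result rows now start directly with the columns; the only remaining
    // '#'-prefixed header is the '# Partition Information' block.
    spark.sql("DESC t").show(truncate = false)

    spark.stop()
  }
}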
col_name data_type comment a string b int c string @@ -280,7 +271,6 @@ DESC TABLE temp_v -- !query 17 schema struct -- !query 17 output -# col_name data_type comment a string b int c string @@ -292,7 +282,6 @@ DESC FORMATTED temp_v -- !query 18 schema struct -- !query 18 output -# col_name data_type comment a string b int c string @@ -304,7 +293,6 @@ DESC EXTENDED temp_v -- !query 19 schema struct -- !query 19 output -# col_name data_type comment a string b int c string @@ -316,7 +304,6 @@ DESC temp_Data_Source_View -- !query 20 schema struct -- !query 20 output -# col_name data_type comment intType int test comment test1 stringType string dateType date @@ -349,7 +336,6 @@ DESC v -- !query 22 schema struct -- !query 22 output -# col_name data_type comment a string b int c string @@ -361,7 +347,6 @@ DESC TABLE v -- !query 23 schema struct -- !query 23 output -# col_name data_type comment a string b int c string @@ -373,7 +358,6 @@ DESC FORMATTED v -- !query 24 schema struct -- !query 24 output -# col_name data_type comment a string b int c string @@ -396,7 +380,6 @@ DESC EXTENDED v -- !query 25 schema struct -- !query 25 output -# col_name data_type comment a string b int c string diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index f5855de2038d1..57e8a612fab44 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 46 +-- Number of queries: 50 -- !query 0 @@ -351,32 +351,64 @@ struct -- !query 42 -select floor(0) +select ceil(0.01) -- !query 42 schema -struct +struct -- !query 42 output -0 +1 -- !query 43 -select floor(1) +select ceiling(-0.10) -- !query 43 schema -struct +struct -- !query 43 output -1 +0 -- !query 44 -select floor(1234567890123456) +select floor(0) -- !query 44 schema -struct +struct -- !query 44 output -1234567890123456 +0 -- !query 45 -select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null) +select floor(1) -- !query 45 schema -struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> +struct -- !query 45 output -1 NULL 0 NULL NULL NULL +1 + + +-- !query 46 +select floor(1234567890123456) +-- !query 46 schema +struct +-- !query 46 output +1234567890123456 + + +-- !query 47 +select floor(0.01) +-- !query 47 schema +struct +-- !query 47 output +0 + + +-- !query 48 +select floor(-0.10) +-- !query 48 schema +struct +-- !query 48 output +-1 + + +-- !query 49 +select 1 > 0.00001 +-- !query 49 schema +struct<(CAST(1 AS BIGINT) > 0):boolean> +-- !query 49 output +true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 7b495656b93d7..45afbd29d1907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -191,6 +191,17 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil) } } + + test("SPARK-21041 SparkSession.range()'s behavior is inconsistent with SparkContext.range()") { + val start = java.lang.Long.MAX_VALUE - 3 + val end = java.lang.Long.MIN_VALUE + 2 + 
Seq("false", "true").foreach { value => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value) { + assert(spark.range(start, end, 1).collect.length == 0) + assert(spark.range(start, start, 1).collect.length == 0) + } + } + } } object DataFrameRangeSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 7e2949ab5aece..4126660b5d102 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.immutable.Queue +import scala.collection.mutable.{LinkedHashMap => LHMap} import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.test.SharedSQLContext @@ -30,8 +31,14 @@ case class ListClass(l: List[Int]) case class QueueClass(q: Queue[Int]) +case class MapClass(m: Map[Int, Int]) + +case class LHMapClass(m: LHMap[Int, Int]) + case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) +case class ComplexMapClass(map: MapClass, lhmap: LHMapClass) + package object packageobject { case class PackageClass(value: Int) } @@ -258,11 +265,90 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))) } + test("arbitrary maps") { + checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2)) + checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong)) + checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble)) + checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat)) + checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte)) + checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort)) + checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false)) + checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2")) + checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2))) + checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong)) + + checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2)) + checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong)) + checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble)) + checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat)) + checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte)) + checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort)) + checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false)) + checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2")) + checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2))) + checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong)) + } + + ignore("SPARK-19104: map and product combinations") { + // Case classes + checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2))) + checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) 
-> MapClass(Map(3 -> 4)))).toDS(), + Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + + checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2))) + checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + Map(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + LHMap(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + + val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4))) + checkDataset(Seq(complex).toDS(), complex) + checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex)) + checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5)) + checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex)) + checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex)) + checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5)) + checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex)) + + // Tuples + checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(), + LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))) + + // Complex + checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(), + LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))) + } + test("nested sequences") { checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) } + test("nested maps") { + checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3))) + checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3)) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index b9871afd59e4f..539c63d3cb288 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -297,7 +297,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { } test("outer generator()") { - spark.sessionState.functionRegistry.registerFunction("empty_gen", _ => EmptyGenerator()) + spark.sessionState.functionRegistry + 
.createOrReplaceTempFunction("empty_gen", _ => EmptyGenerator()) checkAnswer( sql("select * from values 1, 2 lateral view outer empty_gen() a as b"), Row(1, null) :: Row(2, null) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 41e9e2c92ca8e..a7efcafa0166a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -109,7 +109,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-14415: All functions should have own descriptions") { for (f <- spark.sessionState.functionRegistry.listFunction()) { - if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f.unquotedString)) { checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 5638c8eeda842..c01666770720c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.QueryExecution @@ -71,10 +72,10 @@ class SessionStateSuite extends SparkFunSuite } test("fork new session and inherit function registry and udf") { - val testFuncName1 = "strlenScala" - val testFuncName2 = "addone" + val testFuncName1 = FunctionIdentifier("strlenScala") + val testFuncName2 = FunctionIdentifier("addone") try { - activeSession.udf.register(testFuncName1, (_: String).length + (_: Int)) + activeSession.udf.register(testFuncName1.funcName, (_: String).length + (_: Int)) val forkedSession = activeSession.cloneSession() // inheritance @@ -86,7 +87,7 @@ class SessionStateSuite extends SparkFunSuite // independence forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1) assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) - activeSession.udf.register(testFuncName2, (_: Int) + 1) + activeSession.udf.register(testFuncName2.funcName, (_: Int) + 1) assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty) } finally { activeSession.sessionState.functionRegistry.dropFunction(testFuncName1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index e66a60d7503f3..65472cda9c1c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1036,24 +1036,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Corrupt records: FAILFAST mode") { - val schema = StructType( - StructField("a", StringType, true) :: Nil) // `FAILFAST` mode should throw an exception for corrupt records. 
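The registry-related hunks above (GeneratorFunctionSuite, SQLQuerySuite, SessionStateSuite) all move from raw string names to FunctionIdentifier keys. A minimal sketch of the resulting test-side pattern, with an illustrative UDF name and body and assuming a SparkSession `spark` in scope:

    import org.apache.spark.sql.catalyst.FunctionIdentifier

    val name = FunctionIdentifier("strlenScala")                 // illustrative name
    spark.udf.register(name.funcName, (s: String) => s.length)   // still registered by plain name
    // lookups and drops now go through the identifier, not the string
    assert(spark.sessionState.functionRegistry.lookupFunction(name).nonEmpty)
    spark.sessionState.functionRegistry.dropFunction(name)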
val exceptionOne = intercept[SparkException] { spark.read .option("mode", "FAILFAST") .json(corruptRecords) - } - assert(exceptionOne.getMessage.contains("JsonParseException")) + }.getMessage + assert(exceptionOne.contains( + "Malformed records are detected in schema inference. Parse Mode: FAILFAST.")) val exceptionTwo = intercept[SparkException] { spark.read .option("mode", "FAILFAST") - .schema(schema) + .schema("a string") .json(corruptRecords) .collect() - } - assert(exceptionTwo.getMessage.contains("JsonParseException")) + }.getMessage + assert(exceptionTwo.contains( + "Malformed records are detected in record parsing. Parse Mode: FAILFAST.")) } test("Corrupt records: DROPMALFORMED mode") { @@ -1944,7 +1944,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .option("mode", "FAILFAST") .json(path) } - assert(exceptionOne.getMessage.contains("Failed to infer a common schema")) + assert(exceptionOne.getMessage.contains("Malformed records are detected in schema " + + "inference. Parse Mode: FAILFAST.")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1954,7 +1955,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(path) .collect() } - assert(exceptionTwo.getMessage.contains("Failed to parse a value")) + assert(exceptionTwo.getMessage.contains("Malformed records are detected in record " + + "parsing. Parse Mode: FAILFAST.")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 2a3d1cf0b298a..80ef4eb75ca53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -21,7 +21,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, In} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.BooleanType @@ -36,7 +37,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { } override def afterAll(): Unit = { - spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF") + spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF")) super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala new file mode 100644 index 0000000000000..bdba536425a43 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => + source.asInstanceOf[RateStreamSource] + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + (rateSource, rateSource.getOffset.get) + } + } + + test("basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("valueAtSecond") { + import RateStreamSource._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, 
Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[ArithmeticException](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[StreamingQueryException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + for (msg <- expectedMessages) { + assert(e.getCause.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index bc641fd280a15..b2d568ce320e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -367,6 +367,7 @@ class CatalogSuite withUserDefinedFunction("fn1" -> true, s"$db.fn2" -> false) { // Try to find non existing functions. 
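The valueAtSecond assertions above encode the ramp-up arithmetic of the rate source. The sketch below is one closed-form reading that reproduces those numbers; it illustrates the math only (integer division, triangular sum during ramp-up, full speed afterwards) and ignores the overflow handling a real implementation needs:

    // speed grows by rowsPerSecond / (rampUpTimeSeconds + 1) each second until full speed
    def valueAtSecondSketch(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
      val speedDelta = rowsPerSecond / (rampUpTimeSeconds + 1)
      if (seconds <= rampUpTimeSeconds) {
        // cumulative rows while ramping up: speedDelta * (1 + 2 + ... + seconds)
        speedDelta * seconds * (seconds + 1) / 2
      } else {
        val rampUpPart = speedDelta * rampUpTimeSeconds * (rampUpTimeSeconds + 1) / 2
        rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
      }
    }

    assert(valueAtSecondSketch(3, 5, 2) == 8)    // matches the assertion above
    assert(valueAtSecondSketch(5, 10, 4) == 30)  // matches the assertion above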
intercept[AnalysisException](spark.catalog.getFunction("fn1")) + intercept[AnalysisException](spark.catalog.getFunction(db, "fn1")) intercept[AnalysisException](spark.catalog.getFunction("fn2")) intercept[AnalysisException](spark.catalog.getFunction(db, "fn2")) @@ -379,6 +380,8 @@ class CatalogSuite assert(fn1.name === "fn1") assert(fn1.database === null) assert(fn1.isTemporary) + // Find a temporary function with database + intercept[AnalysisException](spark.catalog.getFunction(db, "fn1")) // Find a qualified function val fn2 = spark.catalog.getFunction(db, "fn2") @@ -455,6 +458,7 @@ class CatalogSuite // Find a temporary function assert(spark.catalog.functionExists("fn1")) + assert(!spark.catalog.functionExists(db, "fn1")) // Find a qualified function assert(spark.catalog.functionExists(db, "fn2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 5bc36dd30f6d1..2a4039cc5831a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -172,8 +172,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { * * @param isFatalError if this is a fatal error. If so, the error should also be caught by * UncaughtExceptionHandler. + * @param assertFailure a function to verify the error. */ case class ExpectFailure[T <: Throwable : ClassTag]( + assertFailure: Throwable => Unit = _ => {}, isFatalError: Boolean = false) extends StreamAction { val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] override def toString(): String = @@ -455,6 +457,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") streamThreadDeathCause = null } + ef.assertFailure(exception.getCause) } catch { case _: InterruptedException => case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index b49efa6890236..2986b7f1eecfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -78,9 +78,9 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { eventually(Timeout(streamingTimeout)) { require(!q2.isActive) require(q2.exception.isDefined) + assert(spark.streams.get(q2.id) === null) + assert(spark.streams.active.toSet === Set(q3)) } - assert(spark.streams.get(q2.id) === null) - assert(spark.streams.active.toSet === Set(q3)) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index ff3784cab9e26..1d1074a2a7387 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -253,6 +253,8 @@ private[hive] class SparkExecuteStatementOperation( return } else { setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + 
statementId, e.getMessage, SparkUtils.exceptionString(e)) throw e } // Actually do need to catch Throwable as some failures don't inherit from Exception and diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 2e0fa1ef77f88..17589cf44b998 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -72,7 +72,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = listener.getExecutionList + val dataRows = listener.getExecutionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -142,7 +142,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val sessionList = listener.getSessionList val numBatches = sessionList.size val table = if (numBatches > 0) { - val dataRows = sessionList + val dataRows = sessionList.sortBy(_.startTimestamp).reverse val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 38b8605745752..5cd2fdf6437c2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -66,7 +66,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) val timeSinceStart = System.currentTimeMillis() - startTime.getTime
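The ThriftServerPage change above makes both the statement and session tables render newest-first. A tiny sketch of that ordering, using a simplified stand-in for the listener's info class:

    case class ExecInfo(statementId: String, startTimestamp: Long)  // simplified stand-in
    val rows = Seq(ExecInfo("a", 100L), ExecInfo("b", 300L), ExecInfo("c", 200L))
    val newestFirst = rows.sortBy(_.startTimestamp).reverse
    assert(newestFirst.map(_.statementId) == Seq("b", "c", "a"))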
  • - Started at: {startTime.toString} + Started at: {formatDate(startTime)}
  • Time since start: {formatDurationVerbose(timeSinceStart)} @@ -147,42 +147,6 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) {errorSummary}{details} } - /** Generate stats of batch sessions of the thrift server program */ - private def generateSessionStatsTable(): Seq[Node] = { - val sessionList = listener.getSessionList - val numBatches = sessionList.size - val table = if (numBatches > 0) { - val dataRows = - sessionList.sortBy(_.startTimestamp).reverse.map ( session => - Seq( - session.userName, - session.ip, - session.sessionId, - formatDate(session.startTimestamp), - formatDate(session.finishTimestamp), - formatDurationOption(Some(session.totalTime)), - session.totalExecution.toString - ) - ).toSeq - val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", - "Total Execute") - Some(listingTable(headerRow, dataRows)) - } else { - None - } - - val content = -
    Session Statistics
    ++ -
    -
      - {table.getOrElse("No statistics have been generated yet.")} -
    -
    - - content - } - - /** * Returns a human-readable string representing a duration such as "5 second 35 ms" */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 918459fe7c246..7fcf06d66b5ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -527,7 +527,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat /** * Alter a table whose name that matches the one specified in `tableDefinition`, - * assuming the table exists. + * assuming the table exists. This method does not change the properties for data source and + * statistics. * * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. @@ -538,30 +539,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, tableDefinition.identifier.table) verifyTableProperties(tableDefinition) - // convert table statistics to properties so that we can persist them through hive api - val withStatsProps = if (tableDefinition.stats.isDefined) { - val stats = tableDefinition.stats.get - var statsProperties: Map[String, String] = - Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) - if (stats.rowCount.isDefined) { - statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() - } - val colNameTypeMap: Map[String, DataType] = - tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap - stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) - } - } - tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) - } else { - tableDefinition - } - if (tableDefinition.tableType == VIEW) { - client.alterTable(withStatsProps) + client.alterTable(tableDefinition) } else { - val oldTableDef = getRawTable(db, withStatsProps.identifier.table) + val oldTableDef = getRawTable(db, tableDefinition.identifier.table) val newStorage = if (DDLUtils.isHiveTable(tableDefinition)) { tableDefinition.storage @@ -611,12 +592,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat TABLE_PARTITION_PROVIDER -> TABLE_PARTITION_PROVIDER_FILESYSTEM } - // Sets the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, - // to retain the spark specific format if it is. Also add old data source properties to table - // properties, to retain the data source table format. - val oldDataSourceProps = oldTableDef.properties.filter(_._1.startsWith(DATASOURCE_PREFIX)) - val newTableProps = oldDataSourceProps ++ withStatsProps.properties + partitionProviderProp - val newDef = withStatsProps.copy( + // Add old data source properties to table properties, to retain the data source table format. + // Add old stats properties to table properties, to retain spark's stats. + // Set the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, + // to retain the spark specific format if it is. 
+ val propsFromOldTable = oldTableDef.properties.filter { case (k, v) => + k.startsWith(DATASOURCE_PREFIX) || k.startsWith(STATISTICS_PREFIX) + } + val newTableProps = propsFromOldTable ++ tableDefinition.properties + partitionProviderProp + val newDef = tableDefinition.copy( storage = newStorage, schema = oldTableDef.schema, partitionColumnNames = oldTableDef.partitionColumnNames, @@ -647,6 +631,32 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + override def alterTableStats( + db: String, + table: String, + stats: CatalogStatistics): Unit = withClient { + requireTableExists(db, table) + val rawTable = getRawTable(db, table) + + // convert table statistics to properties so that we can persist them through hive client + var statsProperties: Map[String, String] = + Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) + if (stats.rowCount.isDefined) { + statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() + } + val colNameTypeMap: Map[String, DataType] = + rawTable.schema.fields.map(f => (f.name, f.dataType)).toMap + stats.colStats.foreach { case (colName, colStat) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } + } + + val oldTableNonStatsProps = rawTable.properties.filterNot(_._1.startsWith(STATISTICS_PREFIX)) + val updatedTable = rawTable.copy(properties = oldTableNonStatsProps ++ statsProperties) + client.alterTable(updatedTable) + } + override def getTable(db: String, table: String): CatalogTable = withClient { restoreTableMetadata(getRawTable(db, table)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9dd8279efc1f4..ff5afc8e3ce05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.types._ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { // these are def_s and not val/lazy val since the latter would introduce circular references private def sessionState = sparkSession.sessionState - private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache + private def catalogProxy = sparkSession.sessionState.catalog import HiveMetastoreCatalog._ /** These locks guard against multiple attempts to instantiate a table, which wastes memory. 
*/ @@ -61,7 +61,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val key = QualifiedTableName( table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, table.table.toLowerCase) - tableRelationCache.getIfPresent(key) + catalogProxy.getCachedTable(key) } private def getCached( @@ -71,7 +71,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log expectedFileFormat: Class[_ <: FileFormat], partitionSchema: Option[StructType]): Option[LogicalRelation] = { - tableRelationCache.getIfPresent(tableIdentifier) match { + catalogProxy.getCachedTable(tableIdentifier) match { case null => None // Cache miss case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) => val cachedRelationFileFormatClass = relation.fileFormat.getClass @@ -92,21 +92,21 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(logical) } else { // If the cached relation is not updated, we invalidate it right away. - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } case _ => logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + s"However, we are getting a ${relation.fileFormat} from the metastore cache. " + "This cached entry will be invalidated.") - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } case other => logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + s"However, we are getting a $other from the metastore cache. " + "This cached entry will be invalidated.") - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } } @@ -175,7 +175,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileFormat = fileFormat, options = options)(sparkSession = sparkSession) val created = LogicalRelation(fsRelation, updatedTable) - tableRelationCache.put(tableIdentifier, created) + catalogProxy.cacheTable(tableIdentifier, created) created } @@ -203,7 +203,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log className = fileType).resolveRelation(), table = updatedTable) - tableRelationCache.put(tableIdentifier, created) + catalogProxy.cacheTable(tableIdentifier, created) created } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 6227e780c0409..da87f0218e3ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -129,7 +129,7 @@ private[sql] class HiveSessionCatalog( Try(super.lookupFunction(funcName, children)) match { case Success(expr) => expr case Failure(error) => - if (functionRegistry.functionExists(funcName.unquotedString)) { + if (functionRegistry.functionExists(funcName)) { // If the function actually exists in functionRegistry, it means that there is an // error when we create the Expression using the given children. // We need to throw the original exception. @@ -163,7 +163,7 @@ private[sql] class HiveSessionCatalog( // Put this Hive built-in function to our function registry. registerFunction(func, ignoreIfExists = false) // Now, we need to create the Expression. 
- functionRegistry.lookupFunction(functionName, children) + functionRegistry.lookupFunction(functionIdentifier, children) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index b3a06045b5fd4..d271acc63de08 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -46,7 +46,7 @@ class HiveSchemaInferenceSuite override def afterEach(): Unit = { super.afterEach() - spark.sessionState.catalog.tableRelationCache.invalidateAll() + spark.sessionState.catalog.invalidateAllCachedTables() FileStatusCache.resetForTesting() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 5d52f8baa3b94..001bbc230ff18 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -25,7 +25,7 @@ import scala.util.matching.Regex import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -267,7 +267,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - test("get statistics when not analyzed in both Hive and Spark") { + test("get statistics when not analyzed in Hive or Spark") { val tabName = "tab1" withTable(tabName) { createNonPartitionedTable(tabName, analyzedByHive = false, analyzedBySpark = false) @@ -313,60 +313,70 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - test("alter table SET TBLPROPERTIES after analyze table") { - Seq(true, false).foreach { analyzedBySpark => - val tabName = "tab1" - withTable(tabName) { - createNonPartitionedTable(tabName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) - val fetchedStats1 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('foo' = 'a')") - val fetchedStats2 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - assert(fetchedStats1 == fetchedStats2) - - val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") - - val totalSize = extractStatsPropValues(describeResult, "totalSize") - assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + test("alter table should not have the side effect to store statistics in Spark side") { + def getCatalogTable(tableName: String): CatalogTable = { + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + } - // ALTER TABLE SET TBLPROPERTIES invalidates some Hive specific statistics - // This is triggered by the Hive alterTable API - val numRows = extractStatsPropValues(describeResult, "numRows") - assert(numRows.isDefined && numRows.get == -1, "numRows is lost") - val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") - assert(rawDataSize.isDefined && rawDataSize.get == -1, "rawDataSize is lost") - } + val table = 
"alter_table_side_effect" + withTable(table) { + sql(s"CREATE TABLE $table (i string, j string)") + sql(s"INSERT INTO TABLE $table SELECT 'a', 'b'") + val catalogTable1 = getCatalogTable(table) + val hiveSize1 = BigInt(catalogTable1.ignoredProperties(StatsSetupConst.TOTAL_SIZE)) + + sql(s"ALTER TABLE $table SET TBLPROPERTIES ('prop1' = 'a')") + + sql(s"INSERT INTO TABLE $table SELECT 'c', 'd'") + val catalogTable2 = getCatalogTable(table) + val hiveSize2 = BigInt(catalogTable2.ignoredProperties(StatsSetupConst.TOTAL_SIZE)) + // After insertion, Hive's stats should be changed. + assert(hiveSize2 > hiveSize1) + // We haven't generate stats in Spark, so we should still use Hive's stats here. + assert(catalogTable2.stats.get.sizeInBytes == hiveSize2) } } - test("alter table UNSET TBLPROPERTIES after analyze table") { + private def testAlterTableProperties(tabName: String, alterTablePropCmd: String): Unit = { Seq(true, false).foreach { analyzedBySpark => - val tabName = "tab1" withTable(tabName) { createNonPartitionedTable(tabName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) - val fetchedStats1 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - sql(s"ALTER TABLE $tabName UNSET TBLPROPERTIES ('prop1')") - val fetchedStats2 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - assert(fetchedStats1 == fetchedStats2) + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + + // Run ALTER TABLE command + sql(alterTablePropCmd) val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") val totalSize = extractStatsPropValues(describeResult, "totalSize") assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") - // ALTER TABLE UNSET TBLPROPERTIES invalidates some Hive specific statistics - // This is triggered by the Hive alterTable API + // ALTER TABLE SET/UNSET TBLPROPERTIES invalidates some Hive specific statistics, but not + // Spark specific statistics. This is triggered by the Hive alterTable API. 
val numRows = extractStatsPropValues(describeResult, "numRows") assert(numRows.isDefined && numRows.get == -1, "numRows is lost") val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") assert(rawDataSize.isDefined && rawDataSize.get == -1, "rawDataSize is lost") + + if (analyzedBySpark) { + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + } else { + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = None) + } } } } + test("alter table SET TBLPROPERTIES after analyze table") { + testAlterTableProperties("set_prop_table", + "ALTER TABLE set_prop_table SET TBLPROPERTIES ('foo' = 'a')") + } + + test("alter table UNSET TBLPROPERTIES after analyze table") { + testAlterTableProperties("unset_prop_table", + "ALTER TABLE unset_prop_table UNSET TBLPROPERTIES ('prop1')") + } + test("add/drop partitions - managed table") { val catalog = spark.sessionState.catalog val managedTable = "partitionedTable" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index ab931b94987d3..aca964907d4cd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -806,7 +806,7 @@ class HiveDDLSuite checkAnswer( sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("# col_name", "data_type", "comment") :: Row("a", "int", "test") :: Nil + Row("a", "int", "test") :: Nil ) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 8fcbad58350f4..cae338c0ab0ae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -194,7 +194,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { checkAnswer(sql("SELECT percentile_approx(100.0D, array(0.9D, 0.9D)) FROM src LIMIT 1"), sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) - } + } test("UDFIntegerToString") { val testData = spark.sparkContext.parallelize( @@ -592,6 +592,17 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } + test("Temp function has dots in the names") { + withUserDefinedFunction("test_avg" -> false, "`default.test_avg`" -> true) { + sql(s"CREATE FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer(sql("SELECT test_avg(1)"), Row(1.0)) + // temp function containing dots in the name + spark.udf.register("default.test_avg", () => { Math.random() + 2}) + assert(sql("SELECT `default.test_avg`()").head().getDouble(0) >= 2.0) + checkAnswer(sql("SELECT test_avg(1)"), Row(1.0)) + } + } + test("Call the function registered in the not-current database") { Seq("true", "false").foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index da7a0645dbbeb..a949e5e829e14 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -222,7 +222,7 @@ class SQLQuerySuite extends QueryTest with 
SQLTestUtils with TestHiveSingleton { } test("show functions") { - val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted + val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().map(_.unquotedString) val allFunctions = sql("SHOW functions").collect().map(r => r(0)) allBuiltinFunctions.foreach { f => assert(allFunctions.contains(f)) diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 2803cad8095dd..00c59728748f6 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -56,7 +56,7 @@ public abstract class WriteAheadLog { public abstract void clean(long threshTime, boolean waitForCompletion); /** - * Close this log and release any resources. + * Close this log and release any resources. It must be idempotent. */ public abstract void close(); } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index bd7ab0b9bf5eb..6f130c803f310 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -165,11 +165,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Stop the receiver execution thread. */ def stop(graceful: Boolean): Unit = synchronized { - if (isTrackerStarted) { - // First, stop the receivers - trackerState = Stopping + val isStarted: Boolean = isTrackerStarted + trackerState = Stopping + if (isStarted) { if (!skipReceiverLaunch) { - // Send the stop signal to all the receivers + // First, stop the receivers. Send the stop signal to all the receivers endpoint.askSync[Boolean](StopAllReceivers) // Wait for the Spark job that runs the receivers to be over @@ -194,17 +194,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Finally, stop the endpoint ssc.env.rpcEnv.stop(endpoint) endpoint = null - receivedBlockTracker.stop() - logInfo("ReceiverTracker stopped") - trackerState = Stopped - } else if (isTrackerInitialized) { - trackerState = Stopping - // `ReceivedBlockTracker` is open when this instance is created. We should - // close this even if this `ReceiverTracker` is not started. - receivedBlockTracker.stop() - logInfo("ReceiverTracker stopped") - trackerState = Stopped } + + // `ReceivedBlockTracker` is open when this instance is created. We should + // close this even if this `ReceiverTracker` is not started. + receivedBlockTracker.stop() + logInfo("ReceiverTracker stopped") + trackerState = Stopped } /** Allocate all unallocated blocks to the given batch. 
*/ @@ -453,9 +449,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false endpoint.send(StartAllReceivers(receivers)) } - /** Check if tracker has been marked for initiated */ - private def isTrackerInitialized: Boolean = trackerState == Initialized - /** Check if tracker has been marked for starting */ private def isTrackerStarted: Boolean = trackerState == Started diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 35f0166ed0cf2..e522bc62d5cac 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import java.util.{Iterator => JIterator} +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -60,7 +61,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp private val walWriteQueue = new LinkedBlockingQueue[Record]() // Whether the writer thread is active - @volatile private var active: Boolean = true + private val active: AtomicBoolean = new AtomicBoolean(true) private val buffer = new ArrayBuffer[Record]() private val batchedWriterThread = startBatchedWriterThread() @@ -72,7 +73,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { val promise = Promise[WriteAheadLogRecordHandle]() val putSuccessfully = synchronized { - if (active) { + if (active.get()) { walWriteQueue.offer(Record(byteBuffer, time, promise)) true } else { @@ -121,9 +122,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp */ override def close(): Unit = { logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.") - synchronized { - active = false - } + if (!active.getAndSet(false)) return batchedWriterThread.interrupt() batchedWriterThread.join() while (!walWriteQueue.isEmpty) { @@ -138,7 +137,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp private def startBatchedWriterThread(): Thread = { val thread = new Thread(new Runnable { override def run(): Unit = { - while (active) { + while (active.get()) { try { flushRecords() } catch { @@ -166,7 +165,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp } try { var segment: WriteAheadLogRecordHandle = null - if (buffer.length > 0) { + if (buffer.nonEmpty) { logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") // threads may not be able to add items in order by time val sortedByTime = buffer.sortBy(_.time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 1e5f18797e152..d6e15cfdd2723 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -205,10 +205,12 @@ private[streaming] class FileBasedWriteAheadLog( /** Stop the manager, close any open log writer */ def close(): Unit = synchronized { - if (currentLogWriter != null) { - 
currentLogWriter.close() + if (!executionContext.isShutdown) { + if (currentLogWriter != null) { + currentLogWriter.close() + } + executionContext.shutdown() } - executionContext.shutdown() logInfo("Stopped write ahead log manager") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 3c4a2716caf90..fe65353b9d502 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -50,7 +50,6 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) extends SparkFunSuite with BeforeAndAfter with Matchers - with LocalSparkContext with Logging { import WriteAheadLogBasedBlockHandler._ @@ -89,10 +88,9 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) - sc = new SparkContext("local", "test", conf) blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(sc))), conf, true) + new LiveListenerBus(conf))), conf, true) storageLevel = StorageLevel.MEMORY_ONLY_SER blockManager = createBlockManager(blockManagerSize, conf) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index df122ac090c3e..c206d3169d77e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -57,6 +57,8 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } finally { tracker.stop(false) + // Make sure it is idempotent. + tracker.stop(false) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 4bec52b9fe4fe..ede15399f0e2f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -140,6 +140,8 @@ abstract class CommonWriteAheadLogTests( } } writeAheadLog.close() + // Make sure it is idempotent. + writeAheadLog.close() } test(testPrefix + "handling file errors while reading rotating logs") {
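A minimal sketch of the idempotent-close guard that the BatchedWriteAheadLog, FileBasedWriteAheadLog and ReceiverTracker changes above converge on, and that the doubled close()/stop() calls in the tests exercise (the class name is illustrative):

    import java.util.concurrent.atomic.AtomicBoolean

    class ClosableLog {  // illustrative stand-in for a WAL or tracker
      private val active = new AtomicBoolean(true)

      def close(): Unit = {
        // getAndSet flips the flag exactly once; later calls see false and return immediately
        if (!active.getAndSet(false)) return
        // ... release writers, shut down executors, etc., exactly once ...
      }
    }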